diff --git a/src/lerobot/rl/algorithms/base.py b/src/lerobot/rl/algorithms/base.py
index 0add15597..01c34584b 100644
--- a/src/lerobot/rl/algorithms/base.py
+++ b/src/lerobot/rl/algorithms/base.py
@@ -15,25 +15,38 @@
 from __future__ import annotations
 
 import abc
+import builtins
+import os
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import torch
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
+from huggingface_hub.errors import HfHubHTTPError
+from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
 from torch.optim import Optimizer
 
 from lerobot.types import BatchType
+from lerobot.utils.hub import HubMixin
 
 from .configs import RLAlgorithmConfig, TrainingStats
 
 if TYPE_CHECKING:
+    from torch import nn
+
     from ..data_sources.data_mixer import DataMixer
 
+T = TypeVar("T", bound="RLAlgorithm")
 
-class RLAlgorithm(abc.ABC):
+
+class RLAlgorithm(HubMixin, abc.ABC):
     """Base for all RL algorithms."""
 
     config_class: type[RLAlgorithmConfig]
     name: str
+    config: RLAlgorithmConfig
 
     @abc.abstractmethod
     def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
@@ -98,3 +111,97 @@ class RLAlgorithm(abc.ABC):
     def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
         """Load policy state-dict received from the learner."""
         raise NotImplementedError
+
+    @abc.abstractmethod
+    def state_dict(self) -> dict[str, torch.Tensor]:
+        """Algorithm-owned trainable tensors.
+
+        Must return a flat tensor mapping for everything the algorithm owns
+        that is not part of the policy (e.g. critic ensembles, target networks,
+        temperature parameters). Algorithms with no training-only tensors
+        should explicitly return an empty dict.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def load_state_dict(
+        self,
+        state_dict: dict[str, torch.Tensor],
+        device: str | torch.device = "cpu",
+    ) -> None:
+        """In-place load of algorithm-owned tensors.
+
+        Implementations MUST keep the identity of any ``nn.Parameter`` that an
+        optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()``
+        rather than rebinding the attribute.
+        """
+        raise NotImplementedError
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Persist the algorithm's tensors and config to ``save_directory``.
+
+        Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`)
+        and ``config.json`` (via :meth:`RLAlgorithmConfig._save_pretrained`).
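+
+        Example (sketch; the directory and the ``policy`` handle are
+        illustrative, not fixed by this API)::
+
+            algo.save_pretrained("outputs/algorithm")
+            algo = SACAlgorithm.from_pretrained("outputs/algorithm", policy=policy)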
+ """ + tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()} + save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE)) + self.config._save_pretrained(save_directory) + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + policy: nn.Module, + config: RLAlgorithmConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + device: str | torch.device = "cpu", + **algo_kwargs: Any, + ) -> T: + """Build an algorithm and load its weights from ``pretrained_name_or_path``.""" + if config is None: + config = cls.config_class.from_pretrained( + pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + if hasattr(config, "policy_config"): + config.policy_config = policy.config + + instance = cls(policy=policy, config=config, **algo_kwargs) + + model_id = str(pretrained_name_or_path) + if os.path.isdir(model_id): + model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" + ) from e + + tensors = load_safetensors(model_file) + instance.load_state_dict(tensors, device=device) + return instance diff --git a/src/lerobot/rl/algorithms/configs.py b/src/lerobot/rl/algorithms/configs.py index f0a429be8..9448afeb3 100644 --- a/src/lerobot/rl/algorithms/configs.py +++ b/src/lerobot/rl/algorithms/configs.py @@ -15,10 +15,23 @@ from __future__ import annotations import abc +import builtins +import logging +import os from dataclasses import dataclass, field -from typing import Any +from pathlib import Path +from typing import Any, TypeVar import draccus +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import CONFIG_NAME +from huggingface_hub.errors import HfHubHTTPError + +from lerobot.utils.hub import HubMixin + +T = TypeVar("T", bound="RLAlgorithmConfig") + +logger = logging.getLogger(__name__) @dataclass @@ -43,7 +56,7 @@ class TrainingStats: @dataclass -class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC): +class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): """Registry for algorithm configs.""" @property @@ -62,3 +75,64 @@ class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC): Must be overridden by every registered config subclass. 
""" raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()") + + def _save_pretrained(self, save_directory: Path) -> None: + """Serialize this config as ``config.json`` inside ``save_directory``.""" + with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"): + draccus.dump(self, f, indent=4) + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[Any, Any] | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + **algo_kwargs: Any, + ) -> T: + model_id = str(pretrained_name_or_path) + config_file: str | None = None + if Path(model_id).is_dir(): + if CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, CONFIG_NAME) + else: + logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" + ) from e + + if config_file is None: + raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}") + + with draccus.config_type("json"): + instance = draccus.parse(RLAlgorithmConfig, config_file, args=[]) + + if cls is not RLAlgorithmConfig and not isinstance(instance, cls): + raise TypeError( + f"Config at {model_id} has type '{instance.type}' but was loaded via " + f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()." + ) + + for key, value in algo_kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + return instance diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py index 28e024fe0..c4e9b334a 100644 --- a/src/lerobot/rl/algorithms/sac/configuration_sac.py +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -16,6 +16,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.gaussian_actor.configuration_gaussian_actor import ( CriticNetworkConfig, GaussianActorConfig, @@ -87,7 +88,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig): use_torch_compile: bool = False # Policy config - policy_config: GaussianActorConfig | None = None + policy_config: PreTrainedConfig | None = None @classmethod def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig: diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py index 49640eb1a..eeb9f1fc5 100644 --- a/src/lerobot/rl/algorithms/sac/sac_algorithm.py +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -513,6 +513,46 @@ class SACAlgorithm(RLAlgorithm): """Load actor + discrete-critic weights into the policy.""" self.policy.load_actor_weights(weights, device=device) + def state_dict(self) -> dict[str, torch.Tensor]: + """Algorithm-owned trainable tensors. + + Encoder weights are stripped because they are owned by the policy + (``policy.encoder_critic``) and already saved via ``policy.save_pretrained``. 
+ """ + bundle: dict[str, torch.Tensor] = {} + for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items(): + bundle[f"critic_ensemble.{k}"] = v + for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items(): + bundle[f"critic_target.{k}"] = v + if self.discrete_critic_target is not None: + for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items(): + bundle[f"discrete_critic_target.{k}"] = v + bundle["log_alpha"] = self.log_alpha.detach() + return bundle + + def load_state_dict( + self, + state_dict: dict[str, torch.Tensor], + device: str | torch.device = "cpu", + ) -> None: + """In-place load of algorithm-owned tensors. + + ``log_alpha`` is restored via ``Parameter.data.copy_`` so the + ``temperature`` optimizer's reference to the parameter object stays + valid after resume. + """ + critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.") + critic_target_state = _split_prefix(state_dict, "critic_target.") + self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False) + self.critic_target.load_state_dict(critic_target_state, strict=False) + + if self.discrete_critic_target is not None: + discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.") + self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False) + + if "log_alpha" in state_dict: + self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device)) + def get_observation_features( self, observations: Tensor, next_observations: Tensor ) -> tuple[Tensor | None, Tensor | None]: @@ -540,6 +580,16 @@ class SACAlgorithm(RLAlgorithm): return observation_features, next_observation_features +def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Drop ``encoder.*`` keys from a critic-module state dict.""" + return {k: v for k, v in state.items() if not k.startswith("encoder.")} + + +def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: + """Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped.""" + return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)} + + class CriticHead(nn.Module): def __init__( self, diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 3caa07387..f41d9d602 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -56,6 +56,8 @@ from pprint import pformat from typing import TYPE_CHECKING, Any import torch +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from safetensors.torch import load_file as load_safetensors from termcolor import colored from torch import nn from torch.multiprocessing import Queue @@ -77,13 +79,16 @@ from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents from lerobot.utils.constants import ( ACTION, + ALGORITHM_DIR, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, + TRAINING_STEP, ) from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import _grpc_available, require_package +from lerobot.utils.io_utils import load_json, write_json from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -370,7 +375,9 @@ def add_actor_information_and_train( # If we are resuming, we need to load the training state optimizers = algorithm.get_optimizers() - resume_optimization_step, resume_interaction_step 
+    resume_optimization_step, resume_interaction_step = load_training_state(
+        cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device
+    )
 
     logging.info("Starting learner thread")
     interaction_message = None
@@ -464,6 +471,7 @@ def add_actor_information_and_train(
             policy=policy,
             optimizers=optimizers,
             replay_buffer=replay_buffer,
+            algorithm=algorithm,
             offline_replay_buffer=offline_replay_buffer,
             dataset_repo_id=dataset_repo_id,
             fps=fps,
@@ -550,6 +558,7 @@ def save_training_checkpoint(
     policy: nn.Module,
     optimizers: dict[str, Optimizer],
     replay_buffer: ReplayBuffer,
+    algorithm: RLAlgorithm | None = None,
     offline_replay_buffer: ReplayBuffer | None = None,
     dataset_repo_id: str | None = None,
     fps: int = 30,
@@ -588,7 +597,7 @@ def save_training_checkpoint(
     # Create checkpoint directory
     checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
 
-    # Save checkpoint
+    # Save policy artifacts (pretrained_model/) + trainer scaffolding (training_state/).
    save_checkpoint(
         checkpoint_dir=checkpoint_dir,
         step=optimization_step,
@@ -600,11 +609,18 @@ def save_training_checkpoint(
         postprocessor=postprocessor,
     )
 
-    # Save interaction step manually
-    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
-    os.makedirs(training_state_dir, exist_ok=True)
-    training_state = {"step": optimization_step, "interaction_step": interaction_step}
-    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
+    # Algorithm-owned tensors live in their own component subfolder
+    # so they can be `push_to_hub`'d independently and don't bloat the inference artifact.
+    if algorithm is not None:
+        algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR)
+
+    # Enrich training_step.json with the RL-specific interaction_step counter so
+    # both counters can be restored from a single file.
+    training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
+    write_json(
+        {"step": optimization_step, "interaction_step": interaction_step},
+        training_state_dir / TRAINING_STEP,
+    )
 
     # Update the "last" symlink
     update_last_checkpoint(checkpoint_dir)
@@ -698,13 +714,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipeli
 def load_training_state(
     cfg: TrainRLServerPipelineConfig,
     optimizers: Optimizer | dict[str, Optimizer],
+    algorithm: RLAlgorithm | None = None,
+    device: str | torch.device = "cpu",
 ):
     """
-    Loads the training state (optimizers, step count, etc.) from a checkpoint.
+    Loads the training state (optimizers, RNG, step + interaction step, and
+    algorithm-owned tensors) from the most recent checkpoint.
 
     Args:
-        cfg (TrainRLServerPipelineConfig): Training configuration
-        optimizers (Optimizer | dict): Optimizers to load state into
+        cfg: Training configuration.
+        optimizers: Optimizers to load state into.
+        algorithm: Algorithm whose state dict should be restored. Required for a
+            complete resume; the policy itself is restored separately via
+            ``make_policy``.
+        device: Device on which to place loaded algorithm tensors.
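+
+    Note: algorithm tensors are loaded in place, so parameter references already
+    held by the optimizers (e.g. SAC's ``log_alpha``) stay valid after resume.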
 
     Returns:
         tuple: (optimization_step, interaction_step) or (None, None) if not resuming
     """
     if not cfg.resume:
         return None, None
 
     # Construct path to the last checkpoint directory
-    checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
+    checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
 
     logging.info(f"Loading training state from {checkpoint_dir}")
 
     try:
-        # Use the utility function from train_utils which loads the optimizer state
-        step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
+        # Restore optimizers + RNG + step from the standard `training_state/` folder
+        step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None)
 
-        # Load interaction step separately from training_state.pt
-        training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
-        interaction_step = 0
-        if os.path.exists(training_state_path):
-            training_state = torch.load(training_state_path, weights_only=False)  # nosec B614: Safe usage of torch.load
-            interaction_step = training_state.get("interaction_step", 0)
+        # Restore algorithm-owned tensors
+        if algorithm is not None:
+            algo_dir = checkpoint_dir / ALGORITHM_DIR
+            if algo_dir.is_dir():
+                tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE))
+                algorithm.load_state_dict(tensors, device=device)
+                logging.info(f"Loaded algorithm state from {algo_dir}")
+            else:
+                logging.warning(
+                    f"No algorithm state found at {algo_dir}; algorithm-owned tensors will keep "
+                    "their freshly-initialized values. Adam moments restored from the old "
+                    "optimizer state may not match these reset parameters."
+                )
+
+        # Read interaction_step from the enriched training_step.json
+        training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP
+        interaction_step = int(load_json(training_step_path).get("interaction_step", 0))
 
         logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
         return step, interaction_step
diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py
index 43869228d..482394ff6 100644
--- a/src/lerobot/utils/constants.py
+++ b/src/lerobot/utils/constants.py
@@ -47,6 +47,7 @@
 CHECKPOINTS_DIR = "checkpoints"
 LAST_CHECKPOINT_LINK = "last"
 PRETRAINED_MODEL_DIR = "pretrained_model"
 TRAINING_STATE_DIR = "training_state"
+ALGORITHM_DIR = "algorithm"
 RNG_STATE = "rng_state.safetensors"
 TRAINING_STEP = "training_step.json"
 OPTIMIZER_STATE = "optimizer_state.safetensors"
diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py
index e2a6298ff..990d63164 100644
--- a/tests/rl/test_sac_algorithm.py
+++ b/tests/rl/test_sac_algorithm.py
@@ -515,3 +515,88 @@ def test_make_algorithm_builds_sac():
     algorithm = make_algorithm(cfg=algo_config, policy=policy)
     assert isinstance(algorithm, SACAlgorithm)
     assert algorithm.config.utd_ratio == 2
+
+
+# ===========================================================================
+# state_dict / load_state_dict (algorithm-side resume)
+# ===========================================================================
+
+
+def test_state_dict_contains_algorithm_owned_tensors():
+    """state_dict should pack critics, target networks, and log_alpha (no encoder bloat)."""
+    algorithm, _ = _make_algorithm()
+    sd = algorithm.state_dict()
+
+    assert "log_alpha" in sd
+    assert any(k.startswith("critic_ensemble.") for k in sd)
+    assert any(k.startswith("critic_target.") for k in sd)
+    # Encoder weights live on the policy and must not be duplicated here.
+ assert not any(".encoder." in k for k in sd) + + +def test_state_dict_includes_discrete_critic_target_when_present(): + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + sd = algorithm.state_dict() + assert any(k.startswith("discrete_critic_target.") for k in sd) + + +def test_load_state_dict_round_trip_restores_critics_and_log_alpha(): + """state_dict -> load_state_dict on a fresh algorithm restores all bytes exactly.""" + sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) + src_policy = GaussianActorPolicy(config=sac_cfg) + src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg)) + src.make_optimizers_and_scheduler() + # Train a few steps so weights diverge from init (action_dim=7 = 6 continuous + 1 discrete). + src.update(_batch_iterator(action_dim=7)) + src.update(_batch_iterator(action_dim=7)) + + dst_policy = GaussianActorPolicy(config=sac_cfg) + dst = SACAlgorithm(policy=dst_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg)) + dst.make_optimizers_and_scheduler() + + src_sd = src.state_dict() + dst.load_state_dict(src_sd) + dst_sd = dst.state_dict() + + assert set(dst_sd) == set(src_sd) + for key in src_sd: + assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after round-trip" + + +def test_load_state_dict_preserves_log_alpha_parameter_identity(): + """The temperature optimizer holds a reference to log_alpha; identity must survive load.""" + algorithm, _ = _make_algorithm() + log_alpha_id_before = id(algorithm.log_alpha) + optimizer_param_id = id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) + assert log_alpha_id_before == optimizer_param_id + + new_state = algorithm.state_dict() + new_state["log_alpha"] = torch.tensor([0.42]) + algorithm.load_state_dict(new_state) + + assert id(algorithm.log_alpha) == log_alpha_id_before + assert id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) == log_alpha_id_before + assert torch.allclose(algorithm.log_alpha.detach().cpu(), torch.tensor([0.42])) + + +def test_save_pretrained_round_trip_via_disk(tmp_path): + """End-to-end: save_pretrained -> from_pretrained restores tensors and config.""" + sac_cfg = _make_sac_config() + src_policy = GaussianActorPolicy(config=sac_cfg) + src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg)) + src.make_optimizers_and_scheduler() + src.update(_batch_iterator()) + + save_dir = tmp_path / "algorithm" + src.save_pretrained(save_dir) + assert (save_dir / "model.safetensors").is_file() + assert (save_dir / "config.json").is_file() + + dst_policy = GaussianActorPolicy(config=sac_cfg) + dst = SACAlgorithm.from_pretrained(save_dir, policy=dst_policy) + + src_sd = src.state_dict() + dst_sd = dst.state_dict() + assert set(src_sd) == set(dst_sd) + for key in src_sd: + assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after disk round-trip" diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py index 2970f9bc6..b15d4393b 100644 --- a/tests/rl/test_trainer.py +++ b/tests/rl/test_trainer.py @@ -68,6 +68,12 @@ class _DummyRLAlgorithm(RLAlgorithm): def load_weights(self, weights, device="cpu") -> None: _ = (weights, device) + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state_dict, device="cpu") -> None: + _ = (state_dict, device) + class _SimpleMixer: def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):