refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests

This commit is contained in:
Khalil Meftah
2026-05-07 12:09:23 +02:00
parent f1bdd6744f
commit 29fc0c6d28
4 changed files with 58 additions and 44 deletions
+1 -13
View File
@@ -16,13 +16,9 @@ from __future__ import annotations
import abc
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import Any
import draccus
import torch
if TYPE_CHECKING:
from .base import RLAlgorithm
@dataclass
@@ -58,14 +54,6 @@ class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@abc.abstractmethod
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
    """Build the :class:`RLAlgorithm` described by this config.

    Abstract hook: every registered config subclass has to provide its own
    implementation. The body here only raises.

    Args:
        policy: The policy module passed along to the algorithm being built.

    Raises:
        NotImplementedError: Always; the message names the concrete config type.
    """
    owner = type(self).__name__
    raise NotImplementedError(f"{owner} must implement build_algorithm()")
@classmethod
@abc.abstractmethod
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
+53 -1
View File
@@ -43,5 +43,57 @@ def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
return cls(**kwargs)
def get_algorithm_class(name: str) -> type[RLAlgorithm]:
    """
    Resolve a registered algorithm name to its RL algorithm class.

    The implementation module is imported inside this function, so it is only
    loaded when its name is actually requested.

    Args:
        name: The name of the algorithm. ``"sac"`` is the only name handled here.

    Returns:
        The algorithm class registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a recognized algorithm name. The message
            lists the choices known to ``RLAlgorithmConfig``.
    """
    # Guard clause: reject anything that is not a supported name up front.
    if name != "sac":
        raise ValueError(
            f"Algorithm type '{name}' is not available. "
            f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}"
        )

    from .sac.sac_algorithm import SACAlgorithm

    return SACAlgorithm
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
    """
    Instantiate an RL algorithm.

    This factory function looks up the :class:`RLAlgorithm` subclass that matches
    ``cfg.type`` and instantiates it with the provided policy. It also enforces
    that ``cfg.policy_config`` has been populated before construction (this is
    normally handled by :meth:`TrainRLServerPipelineConfig.validate`).

    Args:
        cfg: The algorithm configuration. Must have ``policy_config`` set.
        policy: The policy module the algorithm will train.

    Returns:
        An instantiated :class:`RLAlgorithm`.

    Raises:
        ValueError: If ``cfg.policy_config`` is ``None`` or ``cfg.type`` is not
            registered.
    """
    # Validate before the class lookup so a missing policy config is reported
    # the same way for every algorithm type. ``getattr`` with a default also
    # covers configs that do not define the attribute at all.
    if getattr(cfg, "policy_config", None) is None:
        raise ValueError(
            f"{type(cfg).__name__}.policy_config is None. "
            "It must be populated (typically by TrainRLServerPipelineConfig.validate) "
            "before calling make_algorithm()."
        )

    algorithm_cls = get_algorithm_class(cfg.type)
    return algorithm_cls(policy=policy, config=cfg)
@@ -15,9 +15,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
CriticNetworkConfig,
@@ -26,9 +23,6 @@ from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
from ..configs import RLAlgorithmConfig
if TYPE_CHECKING:
from .sac_algorithm import SACAlgorithm
@RLAlgorithmConfig.register_subclass("sac")
@dataclass
@@ -102,15 +96,3 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
policy_config=policy_cfg,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
)
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
    """Create the :class:`SACAlgorithm` configured by this object.

    Args:
        policy: The policy module handed to the algorithm.

    Returns:
        A ``SACAlgorithm`` built from ``policy`` with this config attached.

    Raises:
        ValueError: If ``self.policy_config`` is ``None``.
    """
    policy_config_missing = self.policy_config is None
    if policy_config_missing:
        raise ValueError(
            "SACAlgorithmConfig.policy_config is None. "
            "It must be populated (typically by TrainRLServerPipelineConfig.validate) "
            "before calling build_algorithm()."
        )

    # Imported at call time rather than when this module is loaded.
    from .sac_algorithm import SACAlgorithm

    algorithm = SACAlgorithm(policy=policy, config=self)
    return algorithm
+4 -12
View File
@@ -501,25 +501,17 @@ def test_training_stats_generic_losses():
# ===========================================================================
# Registry-driven build_algorithm
# Registry-driven make_algorithm
# ===========================================================================
def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
def test_make_algorithm_builds_sac():
    """make_algorithm should look up the SAC class from the registry and instantiate it."""
    sac_cfg = _make_sac_config()
    algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
    # Set a distinctive value so the final assertion can show that this exact
    # config object reached the constructed algorithm.
    algo_config.utd_ratio = 2
    policy = GaussianActorPolicy(config=sac_cfg)

    # Build once, through the factory under test only.
    algorithm = make_algorithm(cfg=algo_config, policy=policy)

    assert isinstance(algorithm, SACAlgorithm)
    assert algorithm.config.utd_ratio == 2
def test_make_algorithm_uses_build_algorithm():
    """make_algorithm should yield a SACAlgorithm for a config derived from a SAC policy config."""
    sac_cfg = _make_sac_config()
    policy = GaussianActorPolicy(config=sac_cfg)
    # Derive the algorithm config after the policy, matching the original
    # evaluation order.
    algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
    built = make_algorithm(cfg=algo_config, policy=policy)
    assert isinstance(built, SACAlgorithm)