From 29fc0c6d283a7ae644322984c2049155fb9fb961 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Thu, 7 May 2026 12:09:23 +0200 Subject: [PATCH] refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests --- src/lerobot/rl/algorithms/configs.py | 14 +---- src/lerobot/rl/algorithms/factory.py | 54 ++++++++++++++++++- .../rl/algorithms/sac/configuration_sac.py | 18 ------- tests/rl/test_sac_algorithm.py | 16 ++---- 4 files changed, 58 insertions(+), 44 deletions(-) diff --git a/src/lerobot/rl/algorithms/configs.py b/src/lerobot/rl/algorithms/configs.py index c87042bc3..f0a429be8 100644 --- a/src/lerobot/rl/algorithms/configs.py +++ b/src/lerobot/rl/algorithms/configs.py @@ -16,13 +16,9 @@ from __future__ import annotations import abc from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import Any import draccus -import torch - -if TYPE_CHECKING: - from .base import RLAlgorithm @dataclass @@ -58,14 +54,6 @@ class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC): raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}") return choice_name - @abc.abstractmethod - def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm: - """Construct the :class:`RLAlgorithm` for this config. - - Must be overridden by every registered config subclass. - """ - raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()") - @classmethod @abc.abstractmethod def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig: diff --git a/src/lerobot/rl/algorithms/factory.py b/src/lerobot/rl/algorithms/factory.py index 8adc9883d..2a5d9dea7 100644 --- a/src/lerobot/rl/algorithms/factory.py +++ b/src/lerobot/rl/algorithms/factory.py @@ -43,5 +43,57 @@ def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig: return cls(**kwargs) +def get_algorithm_class(name: str) -> type[RLAlgorithm]: + """ + Retrieves an RL algorithm class by its registered name. + + This function uses dynamic imports to avoid loading all algorithm classes into + memory at once, improving startup time and reducing dependencies. + + Args: + name: The name of the algorithm. Supported names are "sac". + + Returns: + The algorithm class corresponding to the given name. + + Raises: + ValueError: If the algorithm name is not recognized. + """ + if name == "sac": + from .sac.sac_algorithm import SACAlgorithm + + return SACAlgorithm + raise ValueError( + f"Algorithm type '{name}' is not available. " + f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}" + ) + + def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm: - return cfg.build_algorithm(policy) + """ + Instantiate an RL algorithm. + + This factory function looks up the :class:`RLAlgorithm` subclass that matches + ``cfg.type`` and instantiates it with the provided policy. It also enforces + that ``cfg.policy_config`` has been populated before construction (this is + normally handled by :meth:`TrainRLServerPipelineConfig.validate`). + + Args: + cfg: The algorithm configuration. Must have ``policy_config`` set. + policy: The policy module the algorithm will train. + + Returns: + An instantiated :class:`RLAlgorithm`. + + Raises: + ValueError: If ``cfg.policy_config`` is ``None`` or ``cfg.type`` is not + registered. + """ + if getattr(cfg, "policy_config", None) is None: + raise ValueError( + f"{type(cfg).__name__}.policy_config is None. " + "It must be populated (typically by TrainRLServerPipelineConfig.validate) " + "before calling make_algorithm()." + ) + cls = get_algorithm_class(cfg.type) + return cls(policy=policy, config=cfg) diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py index d126925ff..28e024fe0 100644 --- a/src/lerobot/rl/algorithms/sac/configuration_sac.py +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -15,9 +15,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING - -import torch from lerobot.policies.gaussian_actor.configuration_gaussian_actor import ( CriticNetworkConfig, @@ -26,9 +23,6 @@ from lerobot.policies.gaussian_actor.configuration_gaussian_actor import ( from ..configs import RLAlgorithmConfig -if TYPE_CHECKING: - from .sac_algorithm import SACAlgorithm - @RLAlgorithmConfig.register_subclass("sac") @dataclass @@ -102,15 +96,3 @@ class SACAlgorithmConfig(RLAlgorithmConfig): policy_config=policy_cfg, discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs, ) - - def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: - if self.policy_config is None: - raise ValueError( - "SACAlgorithmConfig.policy_config is None. " - "It must be populated (typically by TrainRLServerPipelineConfig.validate) " - "before calling build_algorithm()." - ) - - from .sac_algorithm import SACAlgorithm - - return SACAlgorithm(policy=policy, config=self) diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py index ffecd17b8..e2a6298ff 100644 --- a/tests/rl/test_sac_algorithm.py +++ b/tests/rl/test_sac_algorithm.py @@ -501,25 +501,17 @@ def test_training_stats_generic_losses(): # =========================================================================== -# Registry-driven build_algorithm +# Registry-driven make_algorithm # =========================================================================== -def test_build_algorithm_via_config(): - """SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm.""" +def test_make_algorithm_builds_sac(): + """make_algorithm should look up the SAC class from the registry and instantiate it.""" sac_cfg = _make_sac_config() algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) algo_config.utd_ratio = 2 policy = GaussianActorPolicy(config=sac_cfg) - algorithm = algo_config.build_algorithm(policy) + algorithm = make_algorithm(cfg=algo_config, policy=policy) assert isinstance(algorithm, SACAlgorithm) assert algorithm.config.utd_ratio == 2 - - -def test_make_algorithm_uses_build_algorithm(): - """make_algorithm should delegate to config.build_algorithm (no hardcoded if/else).""" - sac_cfg = _make_sac_config() - policy = GaussianActorPolicy(config=sac_cfg) - algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy) - assert isinstance(algorithm, SACAlgorithm)