mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests
This commit is contained in:
@@ -16,13 +16,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
import torch
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .base import RLAlgorithm
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -58,14 +54,6 @@ class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||||
return choice_name
|
return choice_name
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
|
|
||||||
"""Construct the :class:`RLAlgorithm` for this config.
|
|
||||||
|
|
||||||
Must be overridden by every registered config subclass.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
||||||
|
|||||||
@@ -43,5 +43,57 @@ def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
|
|||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_algorithm_class(name: str) -> type[RLAlgorithm]:
    """Return the RL algorithm class registered under ``name``.

    The module that implements the algorithm is imported inside this
    function, so it is only loaded when that algorithm is actually requested.

    Args:
        name: Identifier of the algorithm. ``"sac"`` is the only name this
            function resolves.

    Returns:
        The algorithm class (not an instance) matching ``name``.

    Raises:
        ValueError: If ``name`` is not ``"sac"``. The message lists the
            choices currently registered on :class:`RLAlgorithmConfig`.
    """
    # Guard clause: reject anything we cannot resolve before importing.
    if name != "sac":
        raise ValueError(
            f"Algorithm type '{name}' is not available. "
            f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}"
        )

    # Deferred import: only pull in the SAC implementation on demand.
    from .sac.sac_algorithm import SACAlgorithm

    return SACAlgorithm
|
||||||
|
|
||||||
|
|
||||||
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
    """Instantiate the RL algorithm described by ``cfg``.

    The algorithm class is resolved from ``cfg.type`` through
    :func:`get_algorithm_class` and constructed with ``policy`` and ``cfg``.
    The config is no longer asked to build the algorithm itself
    (``cfg.build_algorithm`` has been removed from the config classes).

    Args:
        cfg: The algorithm configuration. Its ``policy_config`` attribute
            must be set; a missing attribute is treated the same as ``None``.
        policy: The policy module handed to the algorithm constructor.

    Returns:
        An instance of the algorithm class resolved from ``cfg.type``.

    Raises:
        ValueError: If ``cfg.policy_config`` is missing or ``None``, or if
            :func:`get_algorithm_class` does not recognise ``cfg.type``.
    """
    # Fail fast with an actionable message before touching the registry.
    # The message points at TrainRLServerPipelineConfig.validate as the
    # usual place where policy_config gets populated.
    if getattr(cfg, "policy_config", None) is None:
        raise ValueError(
            f"{type(cfg).__name__}.policy_config is None. "
            "It must be populated (typically by TrainRLServerPipelineConfig.validate) "
            "before calling make_algorithm()."
        )

    algorithm_cls = get_algorithm_class(cfg.type)
    return algorithm_cls(policy=policy, config=cfg)
|
||||||
|
|||||||
@@ -15,9 +15,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
||||||
CriticNetworkConfig,
|
CriticNetworkConfig,
|
||||||
@@ -26,9 +23,6 @@ from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
|||||||
|
|
||||||
from ..configs import RLAlgorithmConfig
|
from ..configs import RLAlgorithmConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .sac_algorithm import SACAlgorithm
|
|
||||||
|
|
||||||
|
|
||||||
@RLAlgorithmConfig.register_subclass("sac")
|
@RLAlgorithmConfig.register_subclass("sac")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -102,15 +96,3 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
|||||||
policy_config=policy_cfg,
|
policy_config=policy_cfg,
|
||||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
|
||||||
if self.policy_config is None:
|
|
||||||
raise ValueError(
|
|
||||||
"SACAlgorithmConfig.policy_config is None. "
|
|
||||||
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
|
|
||||||
"before calling build_algorithm()."
|
|
||||||
)
|
|
||||||
|
|
||||||
from .sac_algorithm import SACAlgorithm
|
|
||||||
|
|
||||||
return SACAlgorithm(policy=policy, config=self)
|
|
||||||
|
|||||||
@@ -501,25 +501,17 @@ def test_training_stats_generic_losses():
|
|||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Registry-driven build_algorithm
|
# Registry-driven make_algorithm
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
def test_make_algorithm_builds_sac():
    """make_algorithm should look up the SAC class from the registry and instantiate it."""
    sac_cfg = _make_sac_config()
    algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
    # Override a field after construction so we can check below that the
    # algorithm keeps the config it was given.
    algo_config.utd_ratio = 2
    policy = GaussianActorPolicy(config=sac_cfg)

    # Go through the factory; the config no longer builds the algorithm itself.
    algorithm = make_algorithm(cfg=algo_config, policy=policy)

    assert isinstance(algorithm, SACAlgorithm)
    assert algorithm.config.utd_ratio == 2
|
||||||
|
|
||||||
|
|
||||||
def test_make_algorithm_uses_build_algorithm():
|
|
||||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
|
||||||
sac_cfg = _make_sac_config()
|
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
|
||||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
|
||||||
assert isinstance(algorithm, SACAlgorithm)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user