refactor: enforce mandatory config_class and name attributes in RLAlgorithm

This commit is contained in:
Khalil Meftah
2026-05-07 11:37:02 +02:00
parent 84f74cf0bf
commit 758964984c
+2 -9
View File
@@ -32,15 +32,8 @@ if TYPE_CHECKING:
class RLAlgorithm(abc.ABC):
"""Base for all RL algorithms."""
config_class: type[RLAlgorithmConfig] | None = None
name: str | None = None
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not getattr(cls, "config_class", None):
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
config_class: type[RLAlgorithmConfig]
name: str
@abc.abstractmethod
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: