mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
refactor: enforce mandatory config_class and name attributes in RLAlgorithm
This commit is contained in:
@@ -32,15 +32,8 @@ if TYPE_CHECKING:
|
||||
class RLAlgorithm(abc.ABC):
|
||||
"""Base for all RL algorithms."""
|
||||
|
||||
config_class: type[RLAlgorithmConfig] | None = None
|
||||
name: str | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
config_class: type[RLAlgorithmConfig]
|
||||
name: str
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
|
||||
Reference in New Issue
Block a user