mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
refactor: enforce mandatory config_class and name attributes in RLAlgorithm
This commit is contained in:
@@ -32,15 +32,8 @@ if TYPE_CHECKING:
|
|||||||
class RLAlgorithm(abc.ABC):
|
class RLAlgorithm(abc.ABC):
|
||||||
"""Base for all RL algorithms."""
|
"""Base for all RL algorithms."""
|
||||||
|
|
||||||
config_class: type[RLAlgorithmConfig] | None = None
|
config_class: type[RLAlgorithmConfig]
|
||||||
name: str | None = None
|
name: str
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
|
||||||
super().__init_subclass__(**kwargs)
|
|
||||||
if not getattr(cls, "config_class", None):
|
|
||||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
|
||||||
if not getattr(cls, "name", None):
|
|
||||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||||
|
|||||||
Reference in New Issue
Block a user