refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer

This commit is contained in:
Khalil Meftah
2026-05-07 11:48:41 +02:00
parent 758964984c
commit f1bdd6744f
2 changed files with 7 additions and 6 deletions
+6 -5
View File
@@ -44,7 +44,7 @@ class RLAlgorithm(abc.ABC):
The iterator is owned by the trainer; the algorithm just consumes
from it.
"""
...
raise NotImplementedError
def configure_data_iterator(
self,
@@ -65,13 +65,13 @@ class RLAlgorithm(abc.ABC):
queue_size=queue_size,
)
@abc.abstractmethod
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
"""Create, store, and return the optimizers needed for training.
"""Build and return the optimizers used during training.
Called on the **learner** side after construction. Subclasses must
override this with algorithm-specific optimizer setup.
Called once on the learner side after construction.
"""
return {}
raise NotImplementedError
def get_optimizers(self) -> dict[str, Optimizer]:
"""Return optimizers for checkpointing / external scheduling."""
@@ -97,3 +97,4 @@ class RLAlgorithm(abc.ABC):
@abc.abstractmethod
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner."""
raise NotImplementedError
+1 -1
View File
@@ -27,7 +27,7 @@ class DataMixer(abc.ABC):
@abc.abstractmethod
def sample(self, batch_size: int) -> BatchType:
"""Draw one batch of ``batch_size`` transitions."""
...
raise NotImplementedError
def get_iterator(
self,