mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer
This commit is contained in:
@@ -44,7 +44,7 @@ class RLAlgorithm(abc.ABC):
|
||||
The iterator is owned by the trainer; the algorithm just consumes
|
||||
from it.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
@@ -65,13 +65,13 @@ class RLAlgorithm(abc.ABC):
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
|
||||
"""Create, store, and return the optimizers needed for training.
|
||||
"""Build and return the optimizers used during training.
|
||||
|
||||
Called on the **learner** side after construction. Subclasses must
|
||||
override this with algorithm-specific optimizer setup.
|
||||
Called once on the learner side after construction.
|
||||
"""
|
||||
return {}
|
||||
raise NotImplementedError
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Return optimizers for checkpointing / external scheduling."""
|
||||
@@ -97,3 +97,4 @@ class RLAlgorithm(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -27,7 +27,7 @@ class DataMixer(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
"""Draw one batch of ``batch_size`` transitions."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user