diff --git a/src/lerobot/rl/__init__.py b/src/lerobot/rl/__init__.py index 6a7c750d3..314781c9f 100644 --- a/src/lerobot/rl/__init__.py +++ b/src/lerobot/rl/__init__.py @@ -16,19 +16,29 @@ Reinforcement learning modules. Requires: ``pip install 'lerobot[hilserl]'`` - -Available modules (import directly):: - - from lerobot.rl.actor import ... - from lerobot.rl.learner import ... - from lerobot.rl.learner_service import ... - from lerobot.rl.buffer import ... - from lerobot.rl.eval_policy import ... - from lerobot.rl.gym_manipulator import ... """ from lerobot.utils.import_utils import require_package require_package("grpcio", extra="hilserl", import_name="grpc") -__all__: list[str] = [] +from .algorithms.base import RLAlgorithm as RLAlgorithm +from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats +from .algorithms.factory import make_algorithm as make_algorithm +from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig +from .buffer import ReplayBuffer as ReplayBuffer +from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer +from .trainer import RLTrainer as RLTrainer + +__all__ = [ + "RLAlgorithm", + "RLAlgorithmConfig", + "TrainingStats", + "make_algorithm", + "SACAlgorithm", + "SACAlgorithmConfig", + "RLTrainer", + "ReplayBuffer", + "DataMixer", + "OnlineOfflineMixer", +] diff --git a/src/lerobot/rl/algorithms/__init__.py b/src/lerobot/rl/algorithms/__init__.py index fe4a51846..31acda717 100644 --- a/src/lerobot/rl/algorithms/__init__.py +++ b/src/lerobot/rl/algorithms/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.rl.algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig +from .sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig __all__ = [ "SACAlgorithm", diff --git a/src/lerobot/rl/data_sources/__init__.py b/src/lerobot/rl/data_sources/__init__.py index 4ac97ec1b..b4c0bcf3d 100644 --- a/src/lerobot/rl/data_sources/__init__.py +++ b/src/lerobot/rl/data_sources/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer +from .data_mixer import BatchType, DataMixer, OnlineOfflineMixer __all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"] diff --git a/tests/rl/test_queue.py b/tests/rl/test_queue.py index cf3d6cdca..a5a96a9d6 100644 --- a/tests/rl/test_queue.py +++ b/tests/rl/test_queue.py @@ -22,9 +22,9 @@ import pytest pytest.importorskip("grpc") -from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402 +from torch.multiprocessing import Queue as TorchMPQueue -from lerobot.rl.queue import get_last_item_from_queue # noqa: E402 +from lerobot.rl.queue import get_last_item_from_queue def test_get_last_item_single_item(): diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index ce56db173..7ca880057 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -24,7 +24,7 @@ import pytest pytest.importorskip("grpc") -from lerobot.rl.process import ProcessSignalHandler # noqa: E402 +from lerobot.rl.process import ProcessSignalHandler # Fixture to reset shutdown_event_counter and original signal handlers before and after each test