refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages

This commit is contained in:
Khalil Meftah
2026-04-16 15:46:34 +02:00
parent a5ad273b62
commit d7e25c8326
5 changed files with 25 additions and 15 deletions
+20 -10
View File
@@ -16,19 +16,29 @@
Reinforcement learning modules. Reinforcement learning modules.
Requires: ``pip install 'lerobot[hilserl]'`` Requires: ``pip install 'lerobot[hilserl]'``
Available modules (import directly)::
from lerobot.rl.actor import ...
from lerobot.rl.learner import ...
from lerobot.rl.learner_service import ...
from lerobot.rl.buffer import ...
from lerobot.rl.eval_policy import ...
from lerobot.rl.gym_manipulator import ...
""" """
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import require_package
require_package("grpcio", extra="hilserl", import_name="grpc") require_package("grpcio", extra="hilserl", import_name="grpc")
__all__: list[str] = [] from .algorithms.base import RLAlgorithm as RLAlgorithm
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
from .algorithms.factory import make_algorithm as make_algorithm
from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
from .buffer import ReplayBuffer as ReplayBuffer
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
from .trainer import RLTrainer as RLTrainer
__all__ = [
"RLAlgorithm",
"RLAlgorithmConfig",
"TrainingStats",
"make_algorithm",
"SACAlgorithm",
"SACAlgorithmConfig",
"RLTrainer",
"ReplayBuffer",
"DataMixer",
"OnlineOfflineMixer",
]
+1 -1
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lerobot.rl.algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig from .sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
__all__ = [ __all__ = [
"SACAlgorithm", "SACAlgorithm",
+1 -1
View File
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer from .data_mixer import BatchType, DataMixer, OnlineOfflineMixer
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"] __all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
+2 -2
View File
@@ -22,9 +22,9 @@ import pytest
pytest.importorskip("grpc") pytest.importorskip("grpc")
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402 from torch.multiprocessing import Queue as TorchMPQueue
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402 from lerobot.rl.queue import get_last_item_from_queue
def test_get_last_item_single_item(): def test_get_last_item_single_item():
+1 -1
View File
@@ -24,7 +24,7 @@ import pytest
pytest.importorskip("grpc") pytest.importorskip("grpc")
from lerobot.rl.process import ProcessSignalHandler # noqa: E402 from lerobot.rl.process import ProcessSignalHandler
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test # Fixture to reset shutdown_event_counter and original signal handlers before and after each test