mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages
This commit is contained in:
+20
-10
@@ -16,19 +16,29 @@
|
||||
Reinforcement learning modules.
|
||||
|
||||
Requires: ``pip install 'lerobot[hilserl]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.rl.actor import ...
|
||||
from lerobot.rl.learner import ...
|
||||
from lerobot.rl.learner_service import ...
|
||||
from lerobot.rl.buffer import ...
|
||||
from lerobot.rl.eval_policy import ...
|
||||
from lerobot.rl.gym_manipulator import ...
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
||||
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
||||
from .algorithms.factory import make_algorithm as make_algorithm
|
||||
from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||
from .buffer import ReplayBuffer as ReplayBuffer
|
||||
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
||||
from .trainer import RLTrainer as RLTrainer
|
||||
|
||||
__all__ = [
|
||||
"RLAlgorithm",
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"make_algorithm",
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTrainer",
|
||||
"ReplayBuffer",
|
||||
"DataMixer",
|
||||
"OnlineOfflineMixer",
|
||||
]
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||
from .sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||
|
||||
__all__ = [
|
||||
"SACAlgorithm",
|
||||
|
||||
@@ -12,6 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer
|
||||
from .data_mixer import BatchType, DataMixer, OnlineOfflineMixer
|
||||
|
||||
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
|
||||
|
||||
@@ -22,9 +22,9 @@ import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
|
||||
from torch.multiprocessing import Queue as TorchMPQueue
|
||||
|
||||
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
|
||||
|
||||
def test_get_last_item_single_item():
|
||||
|
||||
@@ -24,7 +24,7 @@ import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||
|
||||
Reference in New Issue
Block a user