mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages
This commit is contained in:
+20
-10
@@ -16,19 +16,29 @@
|
|||||||
Reinforcement learning modules.
|
Reinforcement learning modules.
|
||||||
|
|
||||||
Requires: ``pip install 'lerobot[hilserl]'``
|
Requires: ``pip install 'lerobot[hilserl]'``
|
||||||
|
|
||||||
Available modules (import directly)::
|
|
||||||
|
|
||||||
from lerobot.rl.actor import ...
|
|
||||||
from lerobot.rl.learner import ...
|
|
||||||
from lerobot.rl.learner_service import ...
|
|
||||||
from lerobot.rl.buffer import ...
|
|
||||||
from lerobot.rl.eval_policy import ...
|
|
||||||
from lerobot.rl.gym_manipulator import ...
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import require_package
|
||||||
|
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||||
|
|
||||||
__all__: list[str] = []
|
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
||||||
|
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
||||||
|
from .algorithms.factory import make_algorithm as make_algorithm
|
||||||
|
from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||||
|
from .buffer import ReplayBuffer as ReplayBuffer
|
||||||
|
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
||||||
|
from .trainer import RLTrainer as RLTrainer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RLAlgorithm",
|
||||||
|
"RLAlgorithmConfig",
|
||||||
|
"TrainingStats",
|
||||||
|
"make_algorithm",
|
||||||
|
"SACAlgorithm",
|
||||||
|
"SACAlgorithmConfig",
|
||||||
|
"RLTrainer",
|
||||||
|
"ReplayBuffer",
|
||||||
|
"DataMixer",
|
||||||
|
"OnlineOfflineMixer",
|
||||||
|
]
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from lerobot.rl.algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
from .sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SACAlgorithm",
|
"SACAlgorithm",
|
||||||
|
|||||||
@@ -12,6 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer
|
from .data_mixer import BatchType, DataMixer, OnlineOfflineMixer
|
||||||
|
|
||||||
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
|
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ import pytest
|
|||||||
|
|
||||||
pytest.importorskip("grpc")
|
pytest.importorskip("grpc")
|
||||||
|
|
||||||
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
|
from torch.multiprocessing import Queue as TorchMPQueue
|
||||||
|
|
||||||
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
|
from lerobot.rl.queue import get_last_item_from_queue
|
||||||
|
|
||||||
|
|
||||||
def test_get_last_item_single_item():
|
def test_get_last_item_single_item():
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import pytest
|
|||||||
|
|
||||||
pytest.importorskip("grpc")
|
pytest.importorskip("grpc")
|
||||||
|
|
||||||
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
from lerobot.rl.process import ProcessSignalHandler
|
||||||
|
|
||||||
|
|
||||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||||
|
|||||||
Reference in New Issue
Block a user