From e298474bf3d03b18243f3874a086f14cd6965f5d Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 27 Apr 2026 16:53:34 +0200 Subject: [PATCH] fix(tests): gate RL tests on the `datasets` extra --- tests/policies/test_gaussian_actor_policy.py | 19 ++++++++++-------- tests/rl/test_data_mixer.py | 12 +++++++---- tests/rl/test_queue.py | 8 ++++++-- tests/rl/test_sac_algorithm.py | 21 +++++++++++--------- tests/rl/test_trainer.py | 16 +++++++++------ tests/utils/test_process.py | 4 +++- 6 files changed, 50 insertions(+), 30 deletions(-) diff --git a/tests/policies/test_gaussian_actor_policy.py b/tests/policies/test_gaussian_actor_policy.py index a4af959f2..af802d26f 100644 --- a/tests/policies/test_gaussian_actor_policy.py +++ b/tests/policies/test_gaussian_actor_policy.py @@ -15,15 +15,18 @@ # limitations under the License. import pytest -import torch -from torch import Tensor, nn -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig -from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy -from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE -from lerobot.utils.random_utils import seeded_context, set_seed +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 +from torch import Tensor, nn # noqa: E402 + +from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402 +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402 +from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402 +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402 +from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402 try: import transformers # noqa: F401 diff --git a/tests/rl/test_data_mixer.py b/tests/rl/test_data_mixer.py index 90e9e492f..b153498d7 100644 --- a/tests/rl/test_data_mixer.py +++ b/tests/rl/test_data_mixer.py @@ -13,11 +13,15 @@ # limitations under the License. """Tests for RL data mixing (DataMixer, OnlineOfflineMixer).""" -import torch +import pytest -from lerobot.rl.buffer import ReplayBuffer -from lerobot.rl.data_sources import OnlineOfflineMixer -from lerobot.utils.constants import OBS_STATE +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 + +from lerobot.rl.buffer import ReplayBuffer # noqa: E402 +from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402 +from lerobot.utils.constants import OBS_STATE # noqa: E402 def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer: diff --git a/tests/rl/test_queue.py b/tests/rl/test_queue.py index b6716fbd6..77936d269 100644 --- a/tests/rl/test_queue.py +++ b/tests/rl/test_queue.py @@ -18,9 +18,13 @@ import threading import time from queue import Queue -from torch.multiprocessing import Queue as TorchMPQueue +import pytest -from lerobot.rl.queue import get_last_item_from_queue +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402 + +from lerobot.rl.queue import get_last_item_from_queue # noqa: E402 def test_get_last_item_single_item(): diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py index 6f04db8b7..ffecd17b8 100644 --- a/tests/rl/test_sac_algorithm.py +++ b/tests/rl/test_sac_algorithm.py @@ -16,16 +16,19 @@ """Tests for the RL algorithm abstraction and SACAlgorithm implementation.""" import pytest -import torch -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig -from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy -from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats -from lerobot.rl.algorithms.factory import make_algorithm -from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE -from lerobot.utils.random_utils import set_seed +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 + +from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402 +from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402 +from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy # noqa: E402 +from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats # noqa: E402 +from lerobot.rl.algorithms.factory import make_algorithm # noqa: E402 +from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402 +from lerobot.utils.random_utils import set_seed # noqa: E402 # --------------------------------------------------------------------------- # Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py) diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py index 47eaf6ad3..2970f9bc6 100644 --- a/tests/rl/test_trainer.py +++ b/tests/rl/test_trainer.py @@ -14,13 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from torch import Tensor +import pytest -from lerobot.rl.algorithms.base import RLAlgorithm -from lerobot.rl.algorithms.configs import TrainingStats -from lerobot.rl.trainer import RLTrainer -from lerobot.utils.constants import ACTION, OBS_STATE +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 +from torch import Tensor # noqa: E402 + +from lerobot.rl.algorithms.base import RLAlgorithm # noqa: E402 +from lerobot.rl.algorithms.configs import TrainingStats # noqa: E402 +from lerobot.rl.trainer import RLTrainer # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402 class _DummyRLAlgorithmConfig: diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index e2b00cae9..c19d6677b 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -22,7 +22,9 @@ from unittest.mock import patch import pytest -from lerobot.rl.process import ProcessSignalHandler +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +from lerobot.rl.process import ProcessSignalHandler # noqa: E402 # Fixture to reset shutdown_event_counter and original signal handlers before and after each test