fix(tests): gate RL tests on the datasets extra

This commit is contained in:
Khalil Meftah
2026-04-27 16:53:34 +02:00
parent 577f14337a
commit e298474bf3
6 changed files with 50 additions and 30 deletions
+11 -8
View File
@@ -15,15 +15,18 @@
# limitations under the License. # limitations under the License.
import pytest import pytest
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy import torch # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig from torch import Tensor, nn # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402
try: try:
import transformers # noqa: F401 import transformers # noqa: F401
+8 -4
View File
@@ -13,11 +13,15 @@
# limitations under the License. # limitations under the License.
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer).""" """Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
import torch import pytest
from lerobot.rl.buffer import ReplayBuffer pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.utils.constants import OBS_STATE import torch # noqa: E402
from lerobot.rl.buffer import ReplayBuffer # noqa: E402
from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402
from lerobot.utils.constants import OBS_STATE # noqa: E402
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer: def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
+6 -2
View File
@@ -18,9 +18,13 @@ import threading
import time import time
from queue import Queue from queue import Queue
from torch.multiprocessing import Queue as TorchMPQueue import pytest
from lerobot.rl.queue import get_last_item_from_queue pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
def test_get_last_item_single_item(): def test_get_last_item_single_item():
+12 -9
View File
@@ -16,16 +16,19 @@
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation.""" """Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
import pytest import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy import torch # noqa: E402
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
from lerobot.rl.algorithms.factory import make_algorithm from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy # noqa: E402
from lerobot.utils.random_utils import set_seed from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats # noqa: E402
from lerobot.rl.algorithms.factory import make_algorithm # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py) # Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py)
+10 -6
View File
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch import pytest
from torch import Tensor
from lerobot.rl.algorithms.base import RLAlgorithm pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.rl.algorithms.configs import TrainingStats
from lerobot.rl.trainer import RLTrainer import torch # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE from torch import Tensor # noqa: E402
from lerobot.rl.algorithms.base import RLAlgorithm # noqa: E402
from lerobot.rl.algorithms.configs import TrainingStats # noqa: E402
from lerobot.rl.trainer import RLTrainer # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
class _DummyRLAlgorithmConfig: class _DummyRLAlgorithmConfig:
+3 -1
View File
@@ -22,7 +22,9 @@ from unittest.mock import patch
import pytest import pytest
from lerobot.rl.process import ProcessSignalHandler pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test # Fixture to reset shutdown_event_counter and original signal handlers before and after each test