fix(tests): gate RL tests on the datasets extra

This commit is contained in:
Khalil Meftah
2026-04-27 16:53:34 +02:00
parent 577f14337a
commit e298474bf3
6 changed files with 50 additions and 30 deletions
+11 -8
View File
@@ -15,15 +15,18 @@
# limitations under the License.
import pytest
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from torch import Tensor, nn # noqa: E402
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402
try:
import transformers # noqa: F401
+8 -4
View File
@@ -13,11 +13,15 @@
# limitations under the License.
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
import torch
import pytest
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.utils.constants import OBS_STATE
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.rl.buffer import ReplayBuffer # noqa: E402
from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402
from lerobot.utils.constants import OBS_STATE # noqa: E402
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
+6 -2
View File
@@ -18,9 +18,13 @@ import threading
import time
from queue import Queue
from torch.multiprocessing import Queue as TorchMPQueue
import pytest
from lerobot.rl.queue import get_last_item_from_queue
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
def test_get_last_item_single_item():
+12 -9
View File
@@ -16,16 +16,19 @@
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import set_seed
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy # noqa: E402
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats # noqa: E402
from lerobot.rl.algorithms.factory import make_algorithm # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
# ---------------------------------------------------------------------------
# Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py)
+10 -6
View File
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
import pytest
from lerobot.rl.algorithms.base import RLAlgorithm
from lerobot.rl.algorithms.configs import TrainingStats
from lerobot.rl.trainer import RLTrainer
from lerobot.utils.constants import ACTION, OBS_STATE
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from torch import Tensor # noqa: E402
from lerobot.rl.algorithms.base import RLAlgorithm # noqa: E402
from lerobot.rl.algorithms.configs import TrainingStats # noqa: E402
from lerobot.rl.trainer import RLTrainer # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
class _DummyRLAlgorithmConfig:
+3 -1
View File
@@ -22,7 +22,9 @@ from unittest.mock import patch
import pytest
from lerobot.rl.process import ProcessSignalHandler
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test