mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
fix(tests): gate RL tests on the datasets extra
This commit is contained in:
@@ -15,15 +15,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
from torch import Tensor, nn # noqa: E402
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
@@ -13,11 +13,15 @@
|
||||
# limitations under the License.
|
||||
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer # noqa: E402
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402
|
||||
from lerobot.utils.constants import OBS_STATE # noqa: E402
|
||||
|
||||
|
||||
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
|
||||
|
||||
@@ -18,9 +18,13 @@ import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
from torch.multiprocessing import Queue as TorchMPQueue
|
||||
import pytest
|
||||
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
|
||||
|
||||
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
|
||||
|
||||
|
||||
def test_get_last_item_single_item():
|
||||
|
||||
@@ -16,16 +16,19 @@
|
||||
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy # noqa: E402
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats # noqa: E402
|
||||
from lerobot.rl.algorithms.factory import make_algorithm # noqa: E402
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py)
|
||||
|
||||
@@ -14,13 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import pytest
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import TrainingStats
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
from torch import Tensor # noqa: E402
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm # noqa: E402
|
||||
from lerobot.rl.algorithms.configs import TrainingStats # noqa: E402
|
||||
from lerobot.rl.trainer import RLTrainer # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
|
||||
|
||||
class _DummyRLAlgorithmConfig:
|
||||
|
||||
@@ -22,7 +22,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||
|
||||
Reference in New Issue
Block a user