diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 224613416..5d6687102 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from packaging.version import Version from torch.optim.lr_scheduler import LambdaLR @@ -23,8 +24,10 @@ from lerobot.optim.schedulers import ( save_scheduler_state, ) from lerobot.utils.constants import SCHEDULER_STATE +from lerobot.utils.import_utils import is_package_available +@pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed") def test_diffuser_scheduler(optimizer): config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5) scheduler = config.build(optimizer, num_training_steps=100) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index b5cbdec8d..46396b6c5 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -43,12 +43,21 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.feature_utils import dataset_to_policy_features +from lerobot.utils.import_utils import is_package_available from lerobot.utils.random_utils import seeded_context from lerobot.utils.utils import cycle from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel -AVAILABLE_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"] +# Policies that require optional heavy dependencies to instantiate +_POLICY_REQUIRED_PACKAGES: dict[str, tuple[str, ...]] = { + "diffusion": ("diffusers",), +} + +_ALL_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"] +AVAILABLE_POLICIES = [ + p for p in _ALL_POLICIES if all(is_package_available(pkg) for pkg in _POLICY_REQUIRED_PACKAGES.get(p, ())) +] @pytest.fixture @@ -344,7 +353,7 @@ def test_multikey_construction(multikey: bool): # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass. # Thus, we deactivate this test for now. - ( + pytest.param( "lerobot/pusht", "diffusion", { @@ -353,6 +362,7 @@ def test_multikey_construction(multikey: bool): "down_dims": [128, 256, 512], }, "", + marks=pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed"), ), ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""), (