mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
feat(dependencies): minimal default tag install (#3362)
This commit is contained in:
@@ -31,7 +31,7 @@ from lerobot.policies.groot.processor_groot import make_groot_pre_post_processor
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.device_utils import auto_select_torch_device
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
from tests.utils import require_package
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
@@ -37,7 +37,7 @@ def test_classifier_output():
|
||||
)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
@@ -81,7 +81,7 @@ def test_binary_classifier_with_default_params():
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
@@ -123,7 +123,7 @@ def test_multiclass_classifier():
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
@@ -138,7 +138,7 @@ def test_default_device():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
|
||||
@@ -19,15 +19,15 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
||||
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda, require_package # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from tests.utils import require_cuda, skip_if_package_missing
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
@@ -65,7 +65,7 @@ def test_smolvla_rtc_initialization():
|
||||
print("✓ SmolVLA RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization_without_rtc_config():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
@@ -87,7 +87,7 @@ def test_smolvla_rtc_initialization_without_rtc_config():
|
||||
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_with_prev_chunk():
|
||||
@@ -170,7 +170,7 @@ def test_smolvla_rtc_inference_with_prev_chunk():
|
||||
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_without_prev_chunk():
|
||||
@@ -244,7 +244,7 @@ def test_smolvla_rtc_inference_without_prev_chunk():
|
||||
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_validation_rules():
|
||||
|
||||
@@ -20,16 +20,16 @@ from pathlib import Path
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.datasets import make_dataset
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import close_envs, preprocess_observation
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
@@ -45,10 +45,23 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from lerobot.utils.utils import cycle
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
# Policies that require optional heavy dependencies to instantiate
|
||||
_POLICY_REQUIRED_PACKAGES: dict[str, tuple[str, ...]] = {
|
||||
"diffusion": ("diffusers",),
|
||||
}
|
||||
|
||||
_ALL_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
|
||||
AVAILABLE_POLICIES = [
|
||||
p for p in _ALL_POLICIES if all(is_package_available(pkg) for pkg in _POLICY_REQUIRED_PACKAGES.get(p, ()))
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
|
||||
@@ -84,7 +97,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
return ds_meta
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
@@ -255,7 +268,7 @@ def test_act_backbone_lr():
|
||||
assert len(optimizer.param_groups[1]["params"]) == 20
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
"""Check that the policy can be instantiated with defaults."""
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
@@ -268,7 +281,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
policy_cls(policy_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
@@ -343,7 +356,7 @@ def test_multikey_construction(multikey: bool):
|
||||
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
|
||||
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
|
||||
# Thus, we deactivate this test for now.
|
||||
(
|
||||
pytest.param(
|
||||
"lerobot/pusht",
|
||||
"diffusion",
|
||||
{
|
||||
@@ -352,6 +365,7 @@ def test_multikey_construction(multikey: bool):
|
||||
"down_dims": [128, 256, 512],
|
||||
},
|
||||
"",
|
||||
marks=pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed"),
|
||||
),
|
||||
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
||||
(
|
||||
|
||||
@@ -10,6 +10,8 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
Reference in New Issue
Block a user