mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
refactor import fixes
This commit is contained in:
@@ -23,7 +23,6 @@ import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -49,6 +48,8 @@ from lerobot.utils.utils import cycle
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
AVAILABLE_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
|
||||
@@ -84,7 +85,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
return ds_meta
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
@@ -255,7 +256,7 @@ def test_act_backbone_lr():
|
||||
assert len(optimizer.param_groups[1]["params"]) == 20
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
"""Check that the policy can be instantiated with defaults."""
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
@@ -268,7 +269,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
policy_cls(policy_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
|
||||
Reference in New Issue
Block a user