mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
refactor import fixes
This commit is contained in:
@@ -26,7 +26,6 @@ from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import make_dataset
|
||||
@@ -494,13 +493,28 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
# - [ ] remove old tests
|
||||
|
||||
|
||||
ENV_DATASET_POLICY_TRIPLETS = [
|
||||
("aloha", dataset, "act")
|
||||
for dataset in [
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/aloha_sim_insertion_scripted",
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted_image",
|
||||
]
|
||||
] + [
|
||||
("pusht", dataset, policy)
|
||||
for dataset in ["lerobot/pusht", "lerobot/pusht_image"]
|
||||
for policy in ["diffusion", "vqbet"]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
# Single dataset
|
||||
lerobot.env_dataset_policy_triplets,
|
||||
# Multi-dataset
|
||||
# TODO after fix multidataset
|
||||
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||
ENV_DATASET_POLICY_TRIPLETS,
|
||||
)
|
||||
def test_factory(env_name, repo_id, policy_name):
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
from gymnasium.envs.registration import register, registry as gym_registry
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
@@ -36,9 +35,16 @@ from tests.utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
|
||||
ENV_TASK_PAIRS = [
|
||||
("aloha", "AlohaInsertion-v0"),
|
||||
("aloha", "AlohaTransferCube-v0"),
|
||||
("pusht", "PushT-v0"),
|
||||
]
|
||||
AVAILABLE_ENVS = ["aloha", "pusht"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("obs_type", OBS_TYPES)
|
||||
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
|
||||
@pytest.mark.parametrize("env_name, env_task", ENV_TASK_PAIRS)
|
||||
@require_env
|
||||
def test_env(env_name, env_task, obs_type):
|
||||
if env_name == "aloha" and obs_type == "state":
|
||||
@@ -51,7 +57,7 @@ def test_env(env_name, env_task, obs_type):
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_name", lerobot.available_envs)
|
||||
@pytest.mark.parametrize("env_name", AVAILABLE_ENVS)
|
||||
@require_env
|
||||
def test_factory(env_name):
|
||||
cfg = make_env_config(env_name)
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -49,6 +48,8 @@ from lerobot.utils.utils import cycle
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
AVAILABLE_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
|
||||
@@ -84,7 +85,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
return ds_meta
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
@@ -255,7 +256,7 @@ def test_act_backbone_lr():
|
||||
assert len(optimizer.param_groups[1]["params"]) == 20
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
"""Check that the policy can be instantiated with defaults."""
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
@@ -268,7 +269,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
policy_cls(policy_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
|
||||
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
|
||||
+4
-41
@@ -13,48 +13,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
|
||||
import lerobot
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
from tests.utils import require_env
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
|
||||
@require_env
|
||||
def test_available_env_task(env_name: str, task_name: list):
|
||||
"""
|
||||
This test verifies that all environments listed in `lerobot/__init__.py` can
|
||||
be successfully imported — if they're installed — and that their
|
||||
`available_tasks_per_env` are valid.
|
||||
"""
|
||||
package_name = f"gym_{env_name}"
|
||||
importlib.import_module(package_name)
|
||||
gym_handle = f"{package_name}/{task_name}"
|
||||
assert gym_handle in gym.envs.registry, gym_handle
|
||||
|
||||
|
||||
def test_available_policies():
|
||||
"""
|
||||
This test verifies that the class attribute `name` for all policies is
|
||||
consistent with those listed in `lerobot/__init__.py`.
|
||||
"""
|
||||
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
|
||||
policies = [pol_cls.name for pol_cls in policy_classes]
|
||||
assert set(policies) == set(lerobot.available_policies), policies
|
||||
|
||||
|
||||
def test_print():
|
||||
print(lerobot.available_envs)
|
||||
print(lerobot.available_tasks_per_env)
|
||||
print(lerobot.available_datasets)
|
||||
print(lerobot.available_datasets_per_env)
|
||||
print(lerobot.available_real_world_datasets)
|
||||
print(lerobot.available_policies)
|
||||
print(lerobot.available_policies_per_env)
|
||||
def test_version():
|
||||
"""Verify the package exposes a version string."""
|
||||
assert isinstance(lerobot.__version__, str)
|
||||
assert len(lerobot.__version__) > 0
|
||||
|
||||
+19
-4
@@ -20,22 +20,37 @@ from functools import wraps
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.utils.device_utils import auto_select_torch_device
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
|
||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
|
||||
|
||||
AVAILABLE_ROBOTS = [
|
||||
"koch",
|
||||
"koch_bimanual",
|
||||
"aloha",
|
||||
"so100",
|
||||
"so101",
|
||||
]
|
||||
AVAILABLE_CAMERAS = [
|
||||
"opencv",
|
||||
"intelrealsense",
|
||||
]
|
||||
AVAILABLE_MOTORS = [
|
||||
"dynamixel",
|
||||
"feetech",
|
||||
]
|
||||
|
||||
TEST_ROBOT_TYPES = []
|
||||
for robot_type in available_robots:
|
||||
for robot_type in AVAILABLE_ROBOTS:
|
||||
TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
|
||||
|
||||
TEST_CAMERA_TYPES = []
|
||||
for camera_type in available_cameras:
|
||||
for camera_type in AVAILABLE_CAMERAS:
|
||||
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
|
||||
|
||||
TEST_MOTOR_TYPES = []
|
||||
for motor_type in available_motors:
|
||||
for motor_type in AVAILABLE_MOTORS:
|
||||
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
|
||||
|
||||
# Camera indices used for connecting physical cameras
|
||||
|
||||
@@ -17,6 +17,16 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
@@ -27,16 +37,6 @@ from lerobot.utils.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def test_get_step_identifier():
|
||||
@@ -72,7 +72,7 @@ def test_update_last_checkpoint(tmp_path):
|
||||
assert last_checkpoint.resolve() == checkpoint
|
||||
|
||||
|
||||
@patch("lerobot.utils.train_utils.save_training_state")
|
||||
@patch("lerobot.common.train_utils.save_training_state")
|
||||
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
cfg = Mock()
|
||||
@@ -82,7 +82,7 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
mock_save_training_state.assert_called_once()
|
||||
|
||||
|
||||
@patch("lerobot.utils.train_utils.save_training_state")
|
||||
@patch("lerobot.common.train_utils.save_training_state")
|
||||
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
|
||||
Reference in New Issue
Block a user