refactor import fixes

This commit is contained in:
Steven Palma
2026-04-11 18:02:59 +02:00
parent d626964119
commit af0d72bd42
69 changed files with 306 additions and 339 deletions
+20 -6
View File
@@ -26,7 +26,6 @@ from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
import lerobot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import make_dataset
@@ -494,13 +493,28 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
# - [ ] remove old tests
ENV_DATASET_POLICY_TRIPLETS = [
("aloha", dataset, "act")
for dataset in [
"lerobot/aloha_sim_insertion_human",
"lerobot/aloha_sim_insertion_scripted",
"lerobot/aloha_sim_transfer_cube_human",
"lerobot/aloha_sim_transfer_cube_scripted",
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_transfer_cube_scripted_image",
]
] + [
("pusht", dataset, policy)
for dataset in ["lerobot/pusht", "lerobot/pusht_image"]
for policy in ["diffusion", "vqbet"]
]
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
# Single dataset
lerobot.env_dataset_policy_triplets,
# Multi-dataset
# TODO after fix multidataset
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
ENV_DATASET_POLICY_TRIPLETS,
)
def test_factory(env_name, repo_id, policy_name):
"""
+9 -3
View File
@@ -23,7 +23,6 @@ import torch
from gymnasium.envs.registration import register, registry as gym_registry
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config
@@ -36,9 +35,16 @@ from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
ENV_TASK_PAIRS = [
("aloha", "AlohaInsertion-v0"),
("aloha", "AlohaTransferCube-v0"),
("pusht", "PushT-v0"),
]
AVAILABLE_ENVS = ["aloha", "pusht"]
@pytest.mark.parametrize("obs_type", OBS_TYPES)
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
@pytest.mark.parametrize("env_name, env_task", ENV_TASK_PAIRS)
@require_env
def test_env(env_name, env_task, obs_type):
if env_name == "aloha" and obs_type == "state":
@@ -51,7 +57,7 @@ def test_env(env_name, env_task, obs_type):
env.close()
@pytest.mark.parametrize("env_name", lerobot.available_envs)
@pytest.mark.parametrize("env_name", AVAILABLE_ENVS)
@require_env
def test_factory(env_name):
cfg = make_env_config(env_name)
+5 -4
View File
@@ -23,7 +23,6 @@ import torch
from packaging import version
from safetensors.torch import load_file
from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -49,6 +48,8 @@ from lerobot.utils.utils import cycle
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
AVAILABLE_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
@pytest.fixture
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
@@ -84,7 +85,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
return ds_meta
@pytest.mark.parametrize("policy_name", available_policies)
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned."""
policy_cls = get_policy_class(policy_name)
@@ -255,7 +256,7 @@ def test_act_backbone_lr():
assert len(optimizer.param_groups[1]["params"]) == 20
@pytest.mark.parametrize("policy_name", available_policies)
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
"""Check that the policy can be instantiated with defaults."""
policy_cls = get_policy_class(policy_name)
@@ -268,7 +269,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
policy_cls(policy_cfg)
@pytest.mark.parametrize("policy_name", available_policies)
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
+4 -41
View File
@@ -13,48 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym
import pytest
import lerobot
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
from tests.utils import require_env
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
@require_env
def test_available_env_task(env_name: str, task_name: list):
"""
This test verifies that all environments listed in `lerobot/__init__.py` can
be successfully imported if they're installed — and that their
`available_tasks_per_env` are valid.
"""
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry, gym_handle
def test_available_policies():
"""
This test verifies that the class attribute `name` for all policies is
consistent with those listed in `lerobot/__init__.py`.
"""
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies
def test_print():
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_real_world_datasets)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
def test_version():
"""Verify the package exposes a version string."""
assert isinstance(lerobot.__version__, str)
assert len(lerobot.__version__) > 0
+19 -4
View File
@@ -20,22 +20,37 @@ from functools import wraps
import pytest
import torch
from lerobot import available_cameras, available_motors, available_robots
from lerobot.utils.device_utils import auto_select_torch_device
from lerobot.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
AVAILABLE_ROBOTS = [
"koch",
"koch_bimanual",
"aloha",
"so100",
"so101",
]
AVAILABLE_CAMERAS = [
"opencv",
"intelrealsense",
]
AVAILABLE_MOTORS = [
"dynamixel",
"feetech",
]
TEST_ROBOT_TYPES = []
for robot_type in available_robots:
for robot_type in AVAILABLE_ROBOTS:
TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
TEST_CAMERA_TYPES = []
for camera_type in available_cameras:
for camera_type in AVAILABLE_CAMERAS:
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
TEST_MOTOR_TYPES = []
for motor_type in available_motors:
for motor_type in AVAILABLE_MOTORS:
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
# Camera indices used for connecting physical cameras
+12 -12
View File
@@ -17,6 +17,16 @@
from pathlib import Path
from unittest.mock import Mock, patch
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
load_training_step,
save_checkpoint,
save_training_state,
save_training_step,
update_last_checkpoint,
)
from lerobot.utils.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
@@ -27,16 +37,6 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
load_training_step,
save_checkpoint,
save_training_state,
save_training_step,
update_last_checkpoint,
)
def test_get_step_identifier():
@@ -72,7 +72,7 @@ def test_update_last_checkpoint(tmp_path):
assert last_checkpoint.resolve() == checkpoint
@patch("lerobot.utils.train_utils.save_training_state")
@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
cfg = Mock()
@@ -82,7 +82,7 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
mock_save_training_state.assert_called_once()
@patch("lerobot.utils.train_utils.save_training_state")
@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
policy.config = Mock()