mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
feat(dependencies): minimal default tag install (#3362)
This commit is contained in:
@@ -23,7 +23,6 @@ import torch
|
||||
from gymnasium.envs.registration import register, registry as gym_registry
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
@@ -36,9 +35,16 @@ from tests.utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
|
||||
ENV_TASK_PAIRS = [
|
||||
("aloha", "AlohaInsertion-v0"),
|
||||
("aloha", "AlohaTransferCube-v0"),
|
||||
("pusht", "PushT-v0"),
|
||||
]
|
||||
AVAILABLE_ENVS = ["aloha", "pusht"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("obs_type", OBS_TYPES)
|
||||
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
|
||||
@pytest.mark.parametrize("env_name, env_task", ENV_TASK_PAIRS)
|
||||
@require_env
|
||||
def test_env(env_name, env_task, obs_type):
|
||||
if env_name == "aloha" and obs_type == "state":
|
||||
@@ -51,7 +57,7 @@ def test_env(env_name, env_task, obs_type):
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_name", lerobot.available_envs)
|
||||
@pytest.mark.parametrize("env_name", AVAILABLE_ENVS)
|
||||
@require_env
|
||||
def test_factory(env_name):
|
||||
cfg = make_env_config(env_name)
|
||||
|
||||
Reference in New Issue
Block a user