mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
feat(dependencies): minimal default tag install (#3362)
This commit is contained in:
@@ -21,21 +21,22 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
|
||||
from lerobot.datasets import make_dataset
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -46,7 +47,9 @@ from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
@@ -493,13 +496,28 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
# - [ ] remove old tests
|
||||
|
||||
|
||||
ENV_DATASET_POLICY_TRIPLETS = [
|
||||
("aloha", dataset, "act")
|
||||
for dataset in [
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/aloha_sim_insertion_scripted",
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted_image",
|
||||
]
|
||||
] + [
|
||||
("pusht", dataset, policy)
|
||||
for dataset in ["lerobot/pusht", "lerobot/pusht_image"]
|
||||
for policy in ["diffusion", "vqbet"]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
# Single dataset
|
||||
lerobot.env_dataset_policy_triplets,
|
||||
# Multi-dataset
|
||||
# TODO after fix multidataset
|
||||
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||
ENV_DATASET_POLICY_TRIPLETS,
|
||||
)
|
||||
def test_factory(env_name, repo_id, policy_name):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user