mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
feat(dependencies): minimal default tag install (#3362)
This commit is contained in:
@@ -16,7 +16,11 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import datasets # noqa: E402
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
|
||||
@@ -18,6 +18,8 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.compute_stats import (
|
||||
RunningQuantileStats,
|
||||
_assert_type_and_shape,
|
||||
|
||||
@@ -20,6 +20,8 @@ import json
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import INFO_PATH
|
||||
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE
|
||||
|
||||
@@ -15,8 +15,12 @@
|
||||
# limitations under the License.
|
||||
"""Contract tests for DatasetReader."""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.dataset_reader import DatasetReader
|
||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
|
||||
# ── Loading ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_features,
|
||||
delete_episodes,
|
||||
|
||||
@@ -16,13 +16,16 @@
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from datasets import Dataset # noqa: E402
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
||||
|
||||
@@ -23,6 +23,8 @@ import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.dataset_writer import _encode_video_worker
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
|
||||
@@ -21,21 +21,22 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
|
||||
from lerobot.datasets import make_dataset
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -46,7 +47,9 @@ from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
@@ -493,13 +496,28 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
# - [ ] remove old tests
|
||||
|
||||
|
||||
ENV_DATASET_POLICY_TRIPLETS = [
|
||||
("aloha", dataset, "act")
|
||||
for dataset in [
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/aloha_sim_insertion_scripted",
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted_image",
|
||||
]
|
||||
] + [
|
||||
("pusht", dataset, policy)
|
||||
for dataset in ["lerobot/pusht", "lerobot/pusht_image"]
|
||||
for policy in ["diffusion", "vqbet"]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
# Single dataset
|
||||
lerobot.env_dataset_policy_triplets,
|
||||
# Multi-dataset
|
||||
# TODO after fix multidataset
|
||||
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||
ENV_DATASET_POLICY_TRIPLETS,
|
||||
)
|
||||
def test_factory(env_name, repo_id, policy_name):
|
||||
"""
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
|
||||
@@ -21,7 +21,13 @@ from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.datasets.transforms import (
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.scripts.lerobot_imgtransform_viz import (
|
||||
save_all_transforms,
|
||||
save_each_transform,
|
||||
)
|
||||
from lerobot.transforms import (
|
||||
ImageTransformConfig,
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
@@ -29,10 +35,6 @@ from lerobot.datasets.transforms import (
|
||||
SharpnessJitter,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.scripts.lerobot_imgtransform_viz import (
|
||||
save_all_transforms,
|
||||
save_each_transform,
|
||||
)
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -20,6 +20,8 @@ import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.image_writer import (
|
||||
AsyncImageWriter,
|
||||
image_array_to_pil_image,
|
||||
|
||||
@@ -25,6 +25,8 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import lerobot.datasets.dataset_metadata as dataset_metadata_module
|
||||
import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,10 @@ import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from datasets import Dataset # noqa: E402
|
||||
|
||||
from lerobot.datasets.io_utils import (
|
||||
hf_transform_to_torch,
|
||||
|
||||
@@ -17,6 +17,8 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import safe_shard
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
@@ -20,10 +20,13 @@ import queue
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.datasets.video_utils import (
|
||||
VALID_VIDEO_CODECS,
|
||||
StreamingVideoEncoder,
|
||||
|
||||
@@ -23,8 +23,11 @@ These tests verify that:
|
||||
- Subtask handling gracefully handles missing data
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.scripts.lerobot_dataset_viz import visualize_dataset
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user