mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
feat(utils): extend import check util (#2820)
* refactor(utils): is_package_available now differentiate between pkg name and module name * refactor(tests): update require_package decorator
This commit is contained in:
@@ -21,12 +21,23 @@ from typing import Any
|
||||
from draccus.choice_types import ChoiceRegistry
|
||||
|
||||
|
||||
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
||||
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
**Note:** this doesn't work for all packages.
|
||||
def is_package_available(
|
||||
pkg_name: str, import_name: str | None = None, return_version: bool = False
|
||||
) -> tuple[bool, str] | bool:
|
||||
"""
|
||||
package_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
|
||||
Args:
|
||||
pkg_name: The name of the package as installed via pip (e.g. "python-can").
|
||||
import_name: The actual name used to import the package (e.g. "can").
|
||||
Defaults to pkg_name if not provided.
|
||||
return_version: Whether to return the version string.
|
||||
"""
|
||||
if import_name is None:
|
||||
import_name = pkg_name
|
||||
|
||||
# Check if the module spec exists using the import name
|
||||
package_exists = importlib.util.find_spec(import_name) is not None
|
||||
package_version = "N/A"
|
||||
if package_exists:
|
||||
try:
|
||||
@@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
# Fallback method: Only for "torch" and versions containing "dev"
|
||||
if pkg_name == "torch":
|
||||
try:
|
||||
package = importlib.import_module(pkg_name)
|
||||
package = importlib.import_module(import_name)
|
||||
temp_version = getattr(package, "__version__", "N/A")
|
||||
# Check if the version contains "dev"
|
||||
if "dev" in temp_version:
|
||||
@@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
except ImportError:
|
||||
# If the package can't be imported, it's not available
|
||||
package_exists = False
|
||||
elif pkg_name == "grpc":
|
||||
package = importlib.import_module(pkg_name)
|
||||
package_version = getattr(package, "__version__", "N/A")
|
||||
else:
|
||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||
package_exists = False
|
||||
|
||||
Reference in New Issue
Block a user