feat(utils): extend import check util (#2820)

* refactor(utils): is_package_available now differentiate between pkg name and module name

* refactor(tests): update require_package decorator
This commit is contained in:
Steven Palma
2026-01-19 16:43:11 +01:00
committed by GitHub
parent fe068df711
commit 5286ef8439
7 changed files with 67 additions and 59 deletions
+17 -9
View File
@@ -21,12 +21,23 @@ from typing import Any
from draccus.choice_types import ChoiceRegistry
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.
def is_package_available(
pkg_name: str, import_name: str | None = None, return_version: bool = False
) -> tuple[bool, str] | bool:
"""
package_exists = importlib.util.find_spec(pkg_name) is not None
Check if the package spec exists and grab its version to avoid importing a local directory.
Args:
pkg_name: The name of the package as installed via pip (e.g. "python-can").
import_name: The actual name used to import the package (e.g. "can").
Defaults to pkg_name if not provided.
return_version: Whether to return the version string.
"""
if import_name is None:
import_name = pkg_name
# Check if the module spec exists using the import name
package_exists = importlib.util.find_spec(import_name) is not None
package_version = "N/A"
if package_exists:
try:
@@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
package = importlib.import_module(import_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
@@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
elif pkg_name == "grpc":
package = importlib.import_module(pkg_name)
package_version = getattr(package, "__version__", "N/A")
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False