feat(utils): extend import check util (#2820)

* refactor(utils): is_package_available now differentiate between pkg name and module name

* refactor(tests): update require_package decorator
This commit is contained in:
Steven Palma
2026-01-19 16:43:11 +01:00
committed by GitHub
parent fe068df711
commit 5286ef8439
7 changed files with 67 additions and 59 deletions
+17 -9
View File
@@ -21,12 +21,23 @@ from typing import Any
from draccus.choice_types import ChoiceRegistry
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.
def is_package_available(
pkg_name: str, import_name: str | None = None, return_version: bool = False
) -> tuple[bool, str] | bool:
"""
package_exists = importlib.util.find_spec(pkg_name) is not None
Check if the package spec exists and grab its version to avoid importing a local directory.
Args:
pkg_name: The name of the package as installed via pip (e.g. "python-can").
import_name: The actual name used to import the package (e.g. "can").
Defaults to pkg_name if not provided.
return_version: Whether to return the version string.
"""
if import_name is None:
import_name = pkg_name
# Check if the module spec exists using the import name
package_exists = importlib.util.find_spec(import_name) is not None
package_version = "N/A"
if package_exists:
try:
@@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
package = importlib.import_module(import_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
@@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
elif pkg_name == "grpc":
package = importlib.import_module(pkg_name)
package_version = getattr(package, "__version__", "N/A")
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
+1 -1
View File
@@ -62,7 +62,7 @@ class MockPolicy:
@pytest.fixture
@require_package("grpc")
@require_package("grpcio", "grpc")
def policy_server():
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
# Import only when the test actually runs (after decorator check)
+5 -5
View File
@@ -64,7 +64,7 @@ def close_service_stub(channel, server):
server.stop(None)
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_establish_learner_connection_success():
from lerobot.rl.actor import establish_learner_connection
@@ -81,7 +81,7 @@ def test_establish_learner_connection_success():
close_service_stub(channel, server)
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_establish_learner_connection_failure():
from lerobot.rl.actor import establish_learner_connection
@@ -100,7 +100,7 @@ def test_establish_learner_connection_failure():
close_service_stub(channel, server)
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_push_transitions_to_transport_queue():
from lerobot.rl.actor import push_transitions_to_transport_queue
from lerobot.transport.utils import bytes_to_transitions
@@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue():
assert_transitions_equal(deserialized_transition, transitions[i])
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_transitions_stream():
from lerobot.rl.actor import transitions_stream
@@ -167,7 +167,7 @@ def test_transitions_stream():
assert streamed_data[2].data == b"transition_data_3"
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_interactions_stream():
from lerobot.rl.actor import interactions_stream
+3 -3
View File
@@ -88,7 +88,7 @@ def cfg():
return cfg
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
from lerobot.rl.actor import (
@@ -150,7 +150,7 @@ def test_end_to_end_transitions_flow(cfg):
assert_transitions_equal(transition, input_transitions[i])
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
from lerobot.rl.actor import (
@@ -223,7 +223,7 @@ def test_end_to_end_interactions_flow(cfg):
assert received == expected
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
+8 -8
View File
@@ -39,7 +39,7 @@ def learner_service_stub():
close_learner_service_stub(channel, server)
@require_package("grpc")
@require_package("grpcio", "grpc")
def create_learner_service_stub(
shutdown_event: Event,
parameters_queue: Queue,
@@ -75,7 +75,7 @@ def create_learner_service_stub(
return services_pb2_grpc.LearnerServiceStub(channel), channel, server
@require_package("grpc")
@require_package("grpcio", "grpc")
def close_learner_service_stub(channel, server):
channel.close()
server.stop(None)
@@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub):
assert response == services_pb2.Empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_interactions():
from lerobot.transport import services_pb2
@@ -135,7 +135,7 @@ def test_send_interactions():
assert interactions == [b"123", b"4", b"5", b"678"]
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions():
from lerobot.transport import services_pb2
@@ -181,7 +181,7 @@ def test_send_transitions():
assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions_empty_stream():
from lerobot.transport import services_pb2
@@ -209,7 +209,7 @@ def test_send_transitions_empty_stream():
assert transitions_queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_stream_parameters():
import time
@@ -267,7 +267,7 @@ def test_stream_parameters():
assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_with_shutdown():
from lerobot.transport import services_pb2
@@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown():
assert received_params == [b"param_batch_1", b"stop"]
@require_package("grpc")
@require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue():
import threading
+31 -31
View File
@@ -26,7 +26,7 @@ from lerobot.utils.transition import Transition
from tests.utils import require_cuda, require_package
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_bytes_buffer_size_empty_buffer():
from lerobot.transport.utils import bytes_buffer_size
@@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer():
assert buffer.tell() == 0
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_bytes_buffer_size_small_buffer():
from lerobot.transport.utils import bytes_buffer_size
@@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer():
assert buffer.tell() == 0
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_bytes_buffer_size_large_buffer():
from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size
@@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer():
assert buffer.tell() == 0
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_send_bytes_in_chunks_empty_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data():
assert len(chunks) == 0
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_single_chunk_small_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -82,7 +82,7 @@ def test_single_chunk_small_data():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_not_silent_mode():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -94,7 +94,7 @@ def test_not_silent_mode():
assert chunks[0].data == b"Some data"
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_send_bytes_in_chunks_large_data():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data():
assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_empty_data():
from lerobot.transport.utils import receive_bytes_in_chunks
@@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data():
assert queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_single_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk():
assert queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_single_not_end_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk():
assert queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_multiple_chunks():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks():
assert queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_multiple_messages():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages():
assert queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_shutdown_during_receive():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive():
assert queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_only_begin_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk():
assert queue.empty()
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_missing_begin():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin():
# Tests for state_to_bytes and bytes_to_state_dict
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_state_to_bytes_empty_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict():
assert reconstructed == state_dict
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_bytes_to_state_dict_empty_data():
from lerobot.transport.utils import bytes_to_state_dict
@@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data():
bytes_to_state_dict(b"")
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_state_to_bytes_simple_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict():
assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_state_to_bytes_various_dtypes():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes():
assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_bytes_to_state_dict_invalid_data():
from lerobot.transport.utils import bytes_to_state_dict
@@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data():
@require_cuda
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_state_to_bytes_various_dtypes_cuda():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda():
assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_python_object_to_bytes_none():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -439,7 +439,7 @@ def test_python_object_to_bytes_none():
(1, 2, 3),
],
)
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_python_object_to_bytes_simple_types(obj):
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj):
assert type(reconstructed) is type(obj)
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_python_object_to_bytes_with_tensors():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors():
assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"])
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_transitions_to_bytes_empty_list():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list():
assert isinstance(reconstructed, list)
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_transitions_to_bytes_single_transition():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition():
assert_transitions_equal(transitions[0], reconstructed[0])
@require_package("grpc")
@require_package("grpcio", "grpc")
def assert_transitions_equal(t1: Transition, t2: Transition):
"""Helper to assert two transitions are equal."""
assert_observation_equal(t1["state"], t2["state"])
@@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition):
assert_observation_equal(t1["next_state"], t2["next_state"])
@require_package("grpc")
@require_package("grpcio", "grpc")
def assert_observation_equal(o1: dict, o2: dict):
"""Helper to assert two observations are equal."""
assert set(o1.keys()) == set(o2.keys())
@@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict):
assert torch.allclose(o1[key], o2[key])
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_transitions_to_bytes_multiple_transitions():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions():
assert_transitions_equal(original, reconstructed_item)
@require_package("grpc")
@require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_unknown_state():
from lerobot.transport.utils import receive_bytes_in_chunks
+2 -2
View File
@@ -167,7 +167,7 @@ def require_package_arg(func):
return wrapper
def require_package(package_name):
def require_package(package_name, import_name=None):
"""
Decorator that skips the test if the specified package is not installed.
"""
@@ -175,7 +175,7 @@ def require_package(package_name):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not is_package_available(package_name):
if not is_package_available(pkg_name=package_name, import_name=import_name):
pytest.skip(f"{package_name} not installed")
return func(*args, **kwargs)