feat(utils): extend import check util (#2820)

* refactor(utils): is_package_available now differentiate between pkg name and module name

* refactor(tests): update require_package decorator
This commit is contained in:
Steven Palma
2026-01-19 16:43:11 +01:00
committed by GitHub
parent fe068df711
commit 5286ef8439
7 changed files with 67 additions and 59 deletions
+17 -9
View File
@@ -21,12 +21,23 @@ from typing import Any
from draccus.choice_types import ChoiceRegistry from draccus.choice_types import ChoiceRegistry
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: def is_package_available(
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py pkg_name: str, import_name: str | None = None, return_version: bool = False
Check if the package spec exists and grab its version to avoid importing a local directory. ) -> tuple[bool, str] | bool:
**Note:** this doesn't work for all packages.
""" """
package_exists = importlib.util.find_spec(pkg_name) is not None Check if the package spec exists and grab its version to avoid importing a local directory.
Args:
pkg_name: The name of the package as installed via pip (e.g. "python-can").
import_name: The actual name used to import the package (e.g. "can").
Defaults to pkg_name if not provided.
return_version: Whether to return the version string.
"""
if import_name is None:
import_name = pkg_name
# Check if the module spec exists using the import name
package_exists = importlib.util.find_spec(import_name) is not None
package_version = "N/A" package_version = "N/A"
if package_exists: if package_exists:
try: try:
@@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
# Fallback method: Only for "torch" and versions containing "dev" # Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch": if pkg_name == "torch":
try: try:
package = importlib.import_module(pkg_name) package = importlib.import_module(import_name)
temp_version = getattr(package, "__version__", "N/A") temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev" # Check if the version contains "dev"
if "dev" in temp_version: if "dev" in temp_version:
@@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
except ImportError: except ImportError:
# If the package can't be imported, it's not available # If the package can't be imported, it's not available
package_exists = False package_exists = False
elif pkg_name == "grpc":
package = importlib.import_module(pkg_name)
package_version = getattr(package, "__version__", "N/A")
else: else:
# For packages other than "torch", don't attempt the fallback and set as not available # For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False package_exists = False
+1 -1
View File
@@ -62,7 +62,7 @@ class MockPolicy:
@pytest.fixture @pytest.fixture
@require_package("grpc") @require_package("grpcio", "grpc")
def policy_server(): def policy_server():
"""Fresh `PolicyServer` instance with a stubbed-out policy model.""" """Fresh `PolicyServer` instance with a stubbed-out policy model."""
# Import only when the test actually runs (after decorator check) # Import only when the test actually runs (after decorator check)
+5 -5
View File
@@ -64,7 +64,7 @@ def close_service_stub(channel, server):
server.stop(None) server.stop(None)
@require_package("grpc") @require_package("grpcio", "grpc")
def test_establish_learner_connection_success(): def test_establish_learner_connection_success():
from lerobot.rl.actor import establish_learner_connection from lerobot.rl.actor import establish_learner_connection
@@ -81,7 +81,7 @@ def test_establish_learner_connection_success():
close_service_stub(channel, server) close_service_stub(channel, server)
@require_package("grpc") @require_package("grpcio", "grpc")
def test_establish_learner_connection_failure(): def test_establish_learner_connection_failure():
from lerobot.rl.actor import establish_learner_connection from lerobot.rl.actor import establish_learner_connection
@@ -100,7 +100,7 @@ def test_establish_learner_connection_failure():
close_service_stub(channel, server) close_service_stub(channel, server)
@require_package("grpc") @require_package("grpcio", "grpc")
def test_push_transitions_to_transport_queue(): def test_push_transitions_to_transport_queue():
from lerobot.rl.actor import push_transitions_to_transport_queue from lerobot.rl.actor import push_transitions_to_transport_queue
from lerobot.transport.utils import bytes_to_transitions from lerobot.transport.utils import bytes_to_transitions
@@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue():
assert_transitions_equal(deserialized_transition, transitions[i]) assert_transitions_equal(deserialized_transition, transitions[i])
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog @pytest.mark.timeout(3) # force cross-platform watchdog
def test_transitions_stream(): def test_transitions_stream():
from lerobot.rl.actor import transitions_stream from lerobot.rl.actor import transitions_stream
@@ -167,7 +167,7 @@ def test_transitions_stream():
assert streamed_data[2].data == b"transition_data_3" assert streamed_data[2].data == b"transition_data_3"
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog @pytest.mark.timeout(3) # force cross-platform watchdog
def test_interactions_stream(): def test_interactions_stream():
from lerobot.rl.actor import interactions_stream from lerobot.rl.actor import interactions_stream
+3 -3
View File
@@ -88,7 +88,7 @@ def cfg():
return cfg return cfg
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog @pytest.mark.timeout(10) # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg): def test_end_to_end_transitions_flow(cfg):
from lerobot.rl.actor import ( from lerobot.rl.actor import (
@@ -150,7 +150,7 @@ def test_end_to_end_transitions_flow(cfg):
assert_transitions_equal(transition, input_transitions[i]) assert_transitions_equal(transition, input_transitions[i])
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(10) @pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg): def test_end_to_end_interactions_flow(cfg):
from lerobot.rl.actor import ( from lerobot.rl.actor import (
@@ -223,7 +223,7 @@ def test_end_to_end_interactions_flow(cfg):
assert received == expected assert received == expected
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.parametrize("data_size", ["small", "large"]) @pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10) @pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size): def test_end_to_end_parameters_flow(cfg, data_size):
+8 -8
View File
@@ -39,7 +39,7 @@ def learner_service_stub():
close_learner_service_stub(channel, server) close_learner_service_stub(channel, server)
@require_package("grpc") @require_package("grpcio", "grpc")
def create_learner_service_stub( def create_learner_service_stub(
shutdown_event: Event, shutdown_event: Event,
parameters_queue: Queue, parameters_queue: Queue,
@@ -75,7 +75,7 @@ def create_learner_service_stub(
return services_pb2_grpc.LearnerServiceStub(channel), channel, server return services_pb2_grpc.LearnerServiceStub(channel), channel, server
@require_package("grpc") @require_package("grpcio", "grpc")
def close_learner_service_stub(channel, server): def close_learner_service_stub(channel, server):
channel.close() channel.close()
server.stop(None) server.stop(None)
@@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub):
assert response == services_pb2.Empty() assert response == services_pb2.Empty()
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog @pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_interactions(): def test_send_interactions():
from lerobot.transport import services_pb2 from lerobot.transport import services_pb2
@@ -135,7 +135,7 @@ def test_send_interactions():
assert interactions == [b"123", b"4", b"5", b"678"] assert interactions == [b"123", b"4", b"5", b"678"]
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog @pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions(): def test_send_transitions():
from lerobot.transport import services_pb2 from lerobot.transport import services_pb2
@@ -181,7 +181,7 @@ def test_send_transitions():
assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog @pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions_empty_stream(): def test_send_transitions_empty_stream():
from lerobot.transport import services_pb2 from lerobot.transport import services_pb2
@@ -209,7 +209,7 @@ def test_send_transitions_empty_stream():
assert transitions_queue.empty() assert transitions_queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog @pytest.mark.timeout(10) # force cross-platform watchdog
def test_stream_parameters(): def test_stream_parameters():
import time import time
@@ -267,7 +267,7 @@ def test_stream_parameters():
assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog @pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_with_shutdown(): def test_stream_parameters_with_shutdown():
from lerobot.transport import services_pb2 from lerobot.transport import services_pb2
@@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown():
assert received_params == [b"param_batch_1", b"stop"] assert received_params == [b"param_batch_1", b"stop"]
@require_package("grpc") @require_package("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog @pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue(): def test_stream_parameters_waits_and_retries_on_empty_queue():
import threading import threading
+31 -31
View File
@@ -26,7 +26,7 @@ from lerobot.utils.transition import Transition
from tests.utils import require_cuda, require_package from tests.utils import require_cuda, require_package
@require_package("grpc") @require_package("grpcio", "grpc")
def test_bytes_buffer_size_empty_buffer(): def test_bytes_buffer_size_empty_buffer():
from lerobot.transport.utils import bytes_buffer_size from lerobot.transport.utils import bytes_buffer_size
@@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer():
assert buffer.tell() == 0 assert buffer.tell() == 0
@require_package("grpc") @require_package("grpcio", "grpc")
def test_bytes_buffer_size_small_buffer(): def test_bytes_buffer_size_small_buffer():
from lerobot.transport.utils import bytes_buffer_size from lerobot.transport.utils import bytes_buffer_size
@@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer():
assert buffer.tell() == 0 assert buffer.tell() == 0
@require_package("grpc") @require_package("grpcio", "grpc")
def test_bytes_buffer_size_large_buffer(): def test_bytes_buffer_size_large_buffer():
from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size
@@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer():
assert buffer.tell() == 0 assert buffer.tell() == 0
@require_package("grpc") @require_package("grpcio", "grpc")
def test_send_bytes_in_chunks_empty_data(): def test_send_bytes_in_chunks_empty_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data():
assert len(chunks) == 0 assert len(chunks) == 0
@require_package("grpc") @require_package("grpcio", "grpc")
def test_single_chunk_small_data(): def test_single_chunk_small_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -82,7 +82,7 @@ def test_single_chunk_small_data():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpc") @require_package("grpcio", "grpc")
def test_not_silent_mode(): def test_not_silent_mode():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -94,7 +94,7 @@ def test_not_silent_mode():
assert chunks[0].data == b"Some data" assert chunks[0].data == b"Some data"
@require_package("grpc") @require_package("grpcio", "grpc")
def test_send_bytes_in_chunks_large_data(): def test_send_bytes_in_chunks_large_data():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data():
assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpc") @require_package("grpcio", "grpc")
def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_empty_data(): def test_receive_bytes_in_chunks_empty_data():
from lerobot.transport.utils import receive_bytes_in_chunks from lerobot.transport.utils import receive_bytes_in_chunks
@@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data():
assert queue.empty() assert queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_single_chunk(): def test_receive_bytes_in_chunks_single_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk():
assert queue.empty() assert queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_single_not_end_chunk(): def test_receive_bytes_in_chunks_single_not_end_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk():
assert queue.empty() assert queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_multiple_chunks(): def test_receive_bytes_in_chunks_multiple_chunks():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks():
assert queue.empty() assert queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_multiple_messages(): def test_receive_bytes_in_chunks_multiple_messages():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages():
assert queue.empty() assert queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_shutdown_during_receive(): def test_receive_bytes_in_chunks_shutdown_during_receive():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive():
assert queue.empty() assert queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_only_begin_chunk(): def test_receive_bytes_in_chunks_only_begin_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk():
assert queue.empty() assert queue.empty()
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_missing_begin(): def test_receive_bytes_in_chunks_missing_begin():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin():
# Tests for state_to_bytes and bytes_to_state_dict # Tests for state_to_bytes and bytes_to_state_dict
@require_package("grpc") @require_package("grpcio", "grpc")
def test_state_to_bytes_empty_dict(): def test_state_to_bytes_empty_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict():
assert reconstructed == state_dict assert reconstructed == state_dict
@require_package("grpc") @require_package("grpcio", "grpc")
def test_bytes_to_state_dict_empty_data(): def test_bytes_to_state_dict_empty_data():
from lerobot.transport.utils import bytes_to_state_dict from lerobot.transport.utils import bytes_to_state_dict
@@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data():
bytes_to_state_dict(b"") bytes_to_state_dict(b"")
@require_package("grpc") @require_package("grpcio", "grpc")
def test_state_to_bytes_simple_dict(): def test_state_to_bytes_simple_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict():
assert torch.allclose(state_dict[key], reconstructed[key]) assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpc") @require_package("grpcio", "grpc")
def test_state_to_bytes_various_dtypes(): def test_state_to_bytes_various_dtypes():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes():
assert torch.allclose(state_dict[key], reconstructed[key]) assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpc") @require_package("grpcio", "grpc")
def test_bytes_to_state_dict_invalid_data(): def test_bytes_to_state_dict_invalid_data():
from lerobot.transport.utils import bytes_to_state_dict from lerobot.transport.utils import bytes_to_state_dict
@@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data():
@require_cuda @require_cuda
@require_package("grpc") @require_package("grpcio", "grpc")
def test_state_to_bytes_various_dtypes_cuda(): def test_state_to_bytes_various_dtypes_cuda():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda():
assert torch.allclose(state_dict[key], reconstructed[key]) assert torch.allclose(state_dict[key], reconstructed[key])
@require_package("grpc") @require_package("grpcio", "grpc")
def test_python_object_to_bytes_none(): def test_python_object_to_bytes_none():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -439,7 +439,7 @@ def test_python_object_to_bytes_none():
(1, 2, 3), (1, 2, 3),
], ],
) )
@require_package("grpc") @require_package("grpcio", "grpc")
def test_python_object_to_bytes_simple_types(obj): def test_python_object_to_bytes_simple_types(obj):
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj):
assert type(reconstructed) is type(obj) assert type(reconstructed) is type(obj)
@require_package("grpc") @require_package("grpcio", "grpc")
def test_python_object_to_bytes_with_tensors(): def test_python_object_to_bytes_with_tensors():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors():
assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"])
@require_package("grpc") @require_package("grpcio", "grpc")
def test_transitions_to_bytes_empty_list(): def test_transitions_to_bytes_empty_list():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list():
assert isinstance(reconstructed, list) assert isinstance(reconstructed, list)
@require_package("grpc") @require_package("grpcio", "grpc")
def test_transitions_to_bytes_single_transition(): def test_transitions_to_bytes_single_transition():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition():
assert_transitions_equal(transitions[0], reconstructed[0]) assert_transitions_equal(transitions[0], reconstructed[0])
@require_package("grpc") @require_package("grpcio", "grpc")
def assert_transitions_equal(t1: Transition, t2: Transition): def assert_transitions_equal(t1: Transition, t2: Transition):
"""Helper to assert two transitions are equal.""" """Helper to assert two transitions are equal."""
assert_observation_equal(t1["state"], t2["state"]) assert_observation_equal(t1["state"], t2["state"])
@@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition):
assert_observation_equal(t1["next_state"], t2["next_state"]) assert_observation_equal(t1["next_state"], t2["next_state"])
@require_package("grpc") @require_package("grpcio", "grpc")
def assert_observation_equal(o1: dict, o2: dict): def assert_observation_equal(o1: dict, o2: dict):
"""Helper to assert two observations are equal.""" """Helper to assert two observations are equal."""
assert set(o1.keys()) == set(o2.keys()) assert set(o1.keys()) == set(o2.keys())
@@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict):
assert torch.allclose(o1[key], o2[key]) assert torch.allclose(o1[key], o2[key])
@require_package("grpc") @require_package("grpcio", "grpc")
def test_transitions_to_bytes_multiple_transitions(): def test_transitions_to_bytes_multiple_transitions():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions():
assert_transitions_equal(original, reconstructed_item) assert_transitions_equal(original, reconstructed_item)
@require_package("grpc") @require_package("grpcio", "grpc")
def test_receive_bytes_in_chunks_unknown_state(): def test_receive_bytes_in_chunks_unknown_state():
from lerobot.transport.utils import receive_bytes_in_chunks from lerobot.transport.utils import receive_bytes_in_chunks
+2 -2
View File
@@ -167,7 +167,7 @@ def require_package_arg(func):
return wrapper return wrapper
def require_package(package_name): def require_package(package_name, import_name=None):
""" """
Decorator that skips the test if the specified package is not installed. Decorator that skips the test if the specified package is not installed.
""" """
@@ -175,7 +175,7 @@ def require_package(package_name):
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not is_package_available(package_name): if not is_package_available(pkg_name=package_name, import_name=import_name):
pytest.skip(f"{package_name} not installed") pytest.skip(f"{package_name} not installed")
return func(*args, **kwargs) return func(*args, **kwargs)