mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
feat(dependecies): minimal default tag install
This commit is contained in:
@@ -23,7 +23,7 @@ from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
def create_learner_service_stub():
|
||||
@@ -64,7 +64,7 @@ def close_service_stub(channel, server):
|
||||
server.stop(None)
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
def test_establish_learner_connection_success():
|
||||
from lerobot.rl.actor import establish_learner_connection
|
||||
|
||||
@@ -81,7 +81,7 @@ def test_establish_learner_connection_success():
|
||||
close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
def test_establish_learner_connection_failure():
|
||||
from lerobot.rl.actor import establish_learner_connection
|
||||
|
||||
@@ -100,7 +100,7 @@ def test_establish_learner_connection_failure():
|
||||
close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
def test_push_transitions_to_transport_queue():
|
||||
from lerobot.rl.actor import push_transitions_to_transport_queue
|
||||
from lerobot.transport.utils import bytes_to_transitions
|
||||
@@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue():
|
||||
assert_transitions_equal(deserialized_transition, transitions[i])
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_transitions_stream():
|
||||
from lerobot.rl.actor import transitions_stream
|
||||
@@ -167,7 +167,7 @@ def test_transitions_stream():
|
||||
assert streamed_data[2].data == b"transition_data_3"
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_interactions_stream():
|
||||
from lerobot.rl.actor import interactions_stream
|
||||
|
||||
@@ -26,7 +26,7 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
def create_test_transitions(count: int = 3) -> list[Transition]:
|
||||
@@ -88,7 +88,7 @@ def cfg():
|
||||
return cfg
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(10) # force cross-platform watchdog
|
||||
def test_end_to_end_transitions_flow(cfg):
|
||||
from lerobot.rl.actor import (
|
||||
@@ -150,7 +150,7 @@ def test_end_to_end_transitions_flow(cfg):
|
||||
assert_transitions_equal(transition, input_transitions[i])
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(10)
|
||||
def test_end_to_end_interactions_flow(cfg):
|
||||
from lerobot.rl.actor import (
|
||||
@@ -223,7 +223,7 @@ def test_end_to_end_interactions_flow(cfg):
|
||||
assert received == expected
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.parametrize("data_size", ["small", "large"])
|
||||
@pytest.mark.timeout(10)
|
||||
def test_end_to_end_parameters_flow(cfg, data_size):
|
||||
|
||||
@@ -20,7 +20,7 @@ from multiprocessing import Event, Queue
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import require_package # our gRPC servicer class
|
||||
from tests.utils import skip_if_package_missing # our gRPC servicer class
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -39,7 +39,7 @@ def learner_service_stub():
|
||||
close_learner_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
def create_learner_service_stub(
|
||||
shutdown_event: Event,
|
||||
parameters_queue: Queue,
|
||||
@@ -75,7 +75,7 @@ def create_learner_service_stub(
|
||||
return services_pb2_grpc.LearnerServiceStub(channel), channel, server
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
def close_learner_service_stub(channel, server):
|
||||
channel.close()
|
||||
server.stop(None)
|
||||
@@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub):
|
||||
assert response == services_pb2.Empty()
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_send_interactions():
|
||||
from lerobot.transport import services_pb2
|
||||
@@ -135,7 +135,7 @@ def test_send_interactions():
|
||||
assert interactions == [b"123", b"4", b"5", b"678"]
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_send_transitions():
|
||||
from lerobot.transport import services_pb2
|
||||
@@ -181,7 +181,7 @@ def test_send_transitions():
|
||||
assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_send_transitions_empty_stream():
|
||||
from lerobot.transport import services_pb2
|
||||
@@ -209,7 +209,7 @@ def test_send_transitions_empty_stream():
|
||||
assert transitions_queue.empty()
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(10) # force cross-platform watchdog
|
||||
def test_stream_parameters():
|
||||
import time
|
||||
@@ -267,7 +267,7 @@ def test_stream_parameters():
|
||||
assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_stream_parameters_with_shutdown():
|
||||
from lerobot.transport import services_pb2
|
||||
@@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown():
|
||||
assert received_params == [b"param_batch_1", b"stop"]
|
||||
|
||||
|
||||
@require_package("grpcio", "grpc")
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_stream_parameters_waits_and_retries_on_empty_queue():
|
||||
import threading
|
||||
|
||||
Reference in New Issue
Block a user