fix test imports

This commit is contained in:
Steven Palma
2026-04-11 21:07:53 +02:00
parent c9636bb53f
commit 5940126fb5
9 changed files with 71 additions and 42 deletions
+7 -3
View File
@@ -16,10 +16,14 @@ import math
import pickle
import time
import numpy as np
import torch
import pytest
from lerobot.async_inference.helpers import (
pytest.importorskip("grpc")
import numpy as np # noqa: E402
import torch # noqa: E402
from lerobot.async_inference.helpers import ( # noqa: E402
FPSTracker,
TimedAction,
TimedObservation,
+6 -2
View File
@@ -18,9 +18,13 @@ import threading
import time
from queue import Queue
from torch.multiprocessing import Queue as TorchMPQueue
import pytest
from lerobot.rl.queue import get_last_item_from_queue
pytest.importorskip("grpc")
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
def test_get_last_item_single_item():
+3 -1
View File
@@ -22,7 +22,9 @@ from unittest.mock import patch
import pytest
from lerobot.rl.process import ProcessSignalHandler
pytest.importorskip("grpc")
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
+8 -5
View File
@@ -18,12 +18,15 @@ import sys
from collections.abc import Callable
import pytest
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
from tests.fixtures.constants import DUMMY_REPO_ID
pytest.importorskip("grpc")
import torch # noqa: E402
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: E402
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized # noqa: E402
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD # noqa: E402
from tests.fixtures.constants import DUMMY_REPO_ID # noqa: E402
def state_dims() -> list[str]: