test: add dataset guard + fix imports

This commit is contained in:
Steven Palma
2026-04-20 00:36:02 +02:00
parent 4130d4a4a5
commit 8e21268c29
8 changed files with 56 additions and 34 deletions
+18 -19
View File
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# ---------------------------------------------------------------------------
# Import smoke tests
# ---------------------------------------------------------------------------
@@ -54,7 +56,7 @@ def test_strategies_submodule_imports():
def test_strategy_config_types():
from lerobot.rollout.configs import (
from lerobot.rollout import (
BaseStrategyConfig,
DAggerStrategyConfig,
HighlightStrategyConfig,
@@ -68,14 +70,14 @@ def test_strategy_config_types():
def test_dagger_config_invalid_input_device():
from lerobot.rollout.configs import DAggerStrategyConfig
from lerobot.rollout import DAggerStrategyConfig
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
DAggerStrategyConfig(input_device="joystick")
def test_dagger_config_defaults():
from lerobot.rollout.configs import DAggerStrategyConfig
from lerobot.rollout import DAggerStrategyConfig
cfg = DAggerStrategyConfig()
assert cfg.num_episodes == 10
@@ -95,7 +97,7 @@ def test_inference_config_types():
def test_sentry_config_defaults():
from lerobot.rollout.configs import SentryStrategyConfig
from lerobot.rollout import SentryStrategyConfig
cfg = SentryStrategyConfig()
assert cfg.upload_every_n_episodes == 5
@@ -108,7 +110,7 @@ def test_sentry_config_defaults():
def test_ring_buffer_append_and_eviction():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5
@@ -118,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
def test_ring_buffer_drain():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3):
@@ -130,7 +132,7 @@ def test_ring_buffer_drain():
def test_ring_buffer_clear():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1})
@@ -140,7 +142,7 @@ def test_ring_buffer_clear():
def test_ring_buffer_tensor_bytes():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
@@ -154,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
def test_thread_safe_robot_delegates():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -174,7 +176,7 @@ def test_thread_safe_robot_delegates():
def test_thread_safe_robot_properties():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -196,11 +198,8 @@ def test_thread_safe_robot_properties():
def test_create_strategy_dispatches():
from lerobot.rollout.configs import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import create_strategy
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.rollout.strategies.dagger import DAggerStrategy
from lerobot.rollout.strategies.sentry import SentryStrategy
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
@@ -280,7 +279,7 @@ def test_safe_push_to_hub():
def test_dagger_full_transition_cycle():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
assert events.phase == DAggerPhase.AUTONOMOUS
@@ -307,7 +306,7 @@ def test_dagger_full_transition_cycle():
def test_dagger_invalid_transition_ignored():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("correction") # Not valid from AUTONOMOUS
@@ -316,7 +315,7 @@ def test_dagger_invalid_transition_ignored():
def test_dagger_events_reset():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("pause_resume")
@@ -333,7 +332,7 @@ def test_dagger_events_reset():
def test_rollout_context_fields():
from lerobot.rollout.context import RolloutContext
from lerobot.rollout import RolloutContext
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}