test: add dataset guard + fix imports

This commit is contained in:
Steven Palma
2026-04-20 00:36:02 +02:00
parent 4130d4a4a5
commit 8e21268c29
8 changed files with 56 additions and 34 deletions
+2 -3
View File
@@ -227,10 +227,9 @@ See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
```python
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
cfg = RolloutConfig(
+2 -3
View File
@@ -24,10 +24,9 @@ recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
from lerobot.configs import PreTrainedConfig
from lerobot.robots.lekiwi import LeKiwiClientConfig
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
+2 -3
View File
@@ -40,10 +40,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
+2 -3
View File
@@ -38,10 +38,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
+19 -2
View File
@@ -14,6 +14,10 @@
"""Policy deployment engine with pluggable rollout strategies."""
from lerobot.utils.import_utils import require_package
require_package("datasets", extra="dataset")
from .configs import (
BaseStrategyConfig,
DAggerKeyboardConfig,
@@ -25,7 +29,15 @@ from .configs import (
RolloutStrategyConfig,
SentryStrategyConfig,
)
from .context import RolloutContext, build_rollout_context
from .context import (
DatasetContext,
HardwareContext,
PolicyContext,
ProcessorContext,
RolloutContext,
RuntimeContext,
build_rollout_context,
)
from .inference import (
InferenceEngine,
InferenceEngineConfig,
@@ -44,17 +56,22 @@ __all__ = [
"DAggerKeyboardConfig",
"DAggerPedalConfig",
"DAggerStrategyConfig",
"DatasetContext",
"DatasetRecordConfig",
"HardwareContext",
"HighlightStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
"ProcessorContext",
"RTCInferenceConfig",
"RTCInferenceEngine",
"RolloutConfig",
"RolloutContext",
"DatasetRecordConfig",
"RolloutRingBuffer",
"RolloutStrategy",
"RolloutStrategyConfig",
"RuntimeContext",
"SentryStrategyConfig",
"SyncInferenceConfig",
"SyncInferenceEngine",
@@ -14,11 +14,21 @@
"""Rollout strategies — public API re-exports."""
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
__all__ = [
"BaseStrategy",
"DAggerEvents",
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"RolloutStrategy",
"SentryStrategy",
"create_strategy",
"estimate_max_episode_seconds",
"safe_push_to_hub",
+1 -1
View File
@@ -25,7 +25,7 @@ from .highlight import HighlightStrategy
from .sentry import SentryStrategy
if TYPE_CHECKING:
from lerobot.rollout.configs import RolloutStrategyConfig
from lerobot.rollout import RolloutStrategyConfig
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
+18 -19
View File
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# ---------------------------------------------------------------------------
# Import smoke tests
# ---------------------------------------------------------------------------
@@ -54,7 +56,7 @@ def test_strategies_submodule_imports():
def test_strategy_config_types():
from lerobot.rollout.configs import (
from lerobot.rollout import (
BaseStrategyConfig,
DAggerStrategyConfig,
HighlightStrategyConfig,
@@ -68,14 +70,14 @@ def test_strategy_config_types():
def test_dagger_config_invalid_input_device():
from lerobot.rollout.configs import DAggerStrategyConfig
from lerobot.rollout import DAggerStrategyConfig
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
DAggerStrategyConfig(input_device="joystick")
def test_dagger_config_defaults():
from lerobot.rollout.configs import DAggerStrategyConfig
from lerobot.rollout import DAggerStrategyConfig
cfg = DAggerStrategyConfig()
assert cfg.num_episodes == 10
@@ -95,7 +97,7 @@ def test_inference_config_types():
def test_sentry_config_defaults():
from lerobot.rollout.configs import SentryStrategyConfig
from lerobot.rollout import SentryStrategyConfig
cfg = SentryStrategyConfig()
assert cfg.upload_every_n_episodes == 5
@@ -108,7 +110,7 @@ def test_sentry_config_defaults():
def test_ring_buffer_append_and_eviction():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5
@@ -118,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
def test_ring_buffer_drain():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3):
@@ -130,7 +132,7 @@ def test_ring_buffer_drain():
def test_ring_buffer_clear():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1})
@@ -140,7 +142,7 @@ def test_ring_buffer_clear():
def test_ring_buffer_tensor_bytes():
from lerobot.rollout.ring_buffer import RolloutRingBuffer
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
@@ -154,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
def test_thread_safe_robot_delegates():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -174,7 +176,7 @@ def test_thread_safe_robot_delegates():
def test_thread_safe_robot_properties():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -196,11 +198,8 @@ def test_thread_safe_robot_properties():
def test_create_strategy_dispatches():
from lerobot.rollout.configs import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import create_strategy
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.rollout.strategies.dagger import DAggerStrategy
from lerobot.rollout.strategies.sentry import SentryStrategy
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
@@ -280,7 +279,7 @@ def test_safe_push_to_hub():
def test_dagger_full_transition_cycle():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
assert events.phase == DAggerPhase.AUTONOMOUS
@@ -307,7 +306,7 @@ def test_dagger_full_transition_cycle():
def test_dagger_invalid_transition_ignored():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("correction") # Not valid from AUTONOMOUS
@@ -316,7 +315,7 @@ def test_dagger_invalid_transition_ignored():
def test_dagger_events_reset():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("pause_resume")
@@ -333,7 +332,7 @@ def test_dagger_events_reset():
def test_rollout_context_fields():
from lerobot.rollout.context import RolloutContext
from lerobot.rollout import RolloutContext
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}