test: add dataset guard + fix imports

This commit is contained in:
Steven Palma
2026-04-20 00:36:02 +02:00
parent 4130d4a4a5
commit 8e21268c29
8 changed files with 56 additions and 34 deletions
+2 -3
View File
@@ -227,10 +227,9 @@ See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
For custom deployments (e.g. with kinematics processors), use the rollout module API directly: For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
```python ```python
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.process import ProcessSignalHandler
cfg = RolloutConfig( cfg = RolloutConfig(
+2 -3
View File
@@ -24,10 +24,9 @@ recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
from lerobot.configs import PreTrainedConfig from lerobot.configs import PreTrainedConfig
from lerobot.robots.lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi import LeKiwiClientConfig
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
+2 -3
View File
@@ -40,10 +40,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
+2 -3
View File
@@ -38,10 +38,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
+19 -2
View File
@@ -14,6 +14,10 @@
"""Policy deployment engine with pluggable rollout strategies.""" """Policy deployment engine with pluggable rollout strategies."""
from lerobot.utils.import_utils import require_package
require_package("datasets", extra="dataset")
from .configs import ( from .configs import (
BaseStrategyConfig, BaseStrategyConfig,
DAggerKeyboardConfig, DAggerKeyboardConfig,
@@ -25,7 +29,15 @@ from .configs import (
RolloutStrategyConfig, RolloutStrategyConfig,
SentryStrategyConfig, SentryStrategyConfig,
) )
from .context import RolloutContext, build_rollout_context from .context import (
DatasetContext,
HardwareContext,
PolicyContext,
ProcessorContext,
RolloutContext,
RuntimeContext,
build_rollout_context,
)
from .inference import ( from .inference import (
InferenceEngine, InferenceEngine,
InferenceEngineConfig, InferenceEngineConfig,
@@ -44,17 +56,22 @@ __all__ = [
"DAggerKeyboardConfig", "DAggerKeyboardConfig",
"DAggerPedalConfig", "DAggerPedalConfig",
"DAggerStrategyConfig", "DAggerStrategyConfig",
"DatasetContext",
"DatasetRecordConfig",
"HardwareContext",
"HighlightStrategyConfig", "HighlightStrategyConfig",
"InferenceEngine", "InferenceEngine",
"InferenceEngineConfig", "InferenceEngineConfig",
"PolicyContext",
"ProcessorContext",
"RTCInferenceConfig", "RTCInferenceConfig",
"RTCInferenceEngine", "RTCInferenceEngine",
"RolloutConfig", "RolloutConfig",
"RolloutContext", "RolloutContext",
"DatasetRecordConfig",
"RolloutRingBuffer", "RolloutRingBuffer",
"RolloutStrategy", "RolloutStrategy",
"RolloutStrategyConfig", "RolloutStrategyConfig",
"RuntimeContext",
"SentryStrategyConfig", "SentryStrategyConfig",
"SyncInferenceConfig", "SyncInferenceConfig",
"SyncInferenceEngine", "SyncInferenceEngine",
@@ -14,11 +14,21 @@
"""Rollout strategies — public API re-exports.""" """Rollout strategies — public API re-exports."""
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .factory import create_strategy from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
__all__ = [ __all__ = [
"BaseStrategy",
"DAggerEvents",
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"RolloutStrategy", "RolloutStrategy",
"SentryStrategy",
"create_strategy", "create_strategy",
"estimate_max_episode_seconds", "estimate_max_episode_seconds",
"safe_push_to_hub", "safe_push_to_hub",
+1 -1
View File
@@ -25,7 +25,7 @@ from .highlight import HighlightStrategy
from .sentry import SentryStrategy from .sentry import SentryStrategy
if TYPE_CHECKING: if TYPE_CHECKING:
from lerobot.rollout.configs import RolloutStrategyConfig from lerobot.rollout import RolloutStrategyConfig
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
+18 -19
View File
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Import smoke tests # Import smoke tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -54,7 +56,7 @@ def test_strategies_submodule_imports():
def test_strategy_config_types(): def test_strategy_config_types():
from lerobot.rollout.configs import ( from lerobot.rollout import (
BaseStrategyConfig, BaseStrategyConfig,
DAggerStrategyConfig, DAggerStrategyConfig,
HighlightStrategyConfig, HighlightStrategyConfig,
@@ -68,14 +70,14 @@ def test_strategy_config_types():
def test_dagger_config_invalid_input_device(): def test_dagger_config_invalid_input_device():
from lerobot.rollout.configs import DAggerStrategyConfig from lerobot.rollout import DAggerStrategyConfig
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"): with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
DAggerStrategyConfig(input_device="joystick") DAggerStrategyConfig(input_device="joystick")
def test_dagger_config_defaults(): def test_dagger_config_defaults():
from lerobot.rollout.configs import DAggerStrategyConfig from lerobot.rollout import DAggerStrategyConfig
cfg = DAggerStrategyConfig() cfg = DAggerStrategyConfig()
assert cfg.num_episodes == 10 assert cfg.num_episodes == 10
@@ -95,7 +97,7 @@ def test_inference_config_types():
def test_sentry_config_defaults(): def test_sentry_config_defaults():
from lerobot.rollout.configs import SentryStrategyConfig from lerobot.rollout import SentryStrategyConfig
cfg = SentryStrategyConfig() cfg = SentryStrategyConfig()
assert cfg.upload_every_n_episodes == 5 assert cfg.upload_every_n_episodes == 5
@@ -108,7 +110,7 @@ def test_sentry_config_defaults():
def test_ring_buffer_append_and_eviction(): def test_ring_buffer_append_and_eviction():
from lerobot.rollout.ring_buffer import RolloutRingBuffer from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5 # max_frames = 5
@@ -118,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
def test_ring_buffer_drain(): def test_ring_buffer_drain():
from lerobot.rollout.ring_buffer import RolloutRingBuffer from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3): for i in range(3):
@@ -130,7 +132,7 @@ def test_ring_buffer_drain():
def test_ring_buffer_clear(): def test_ring_buffer_clear():
from lerobot.rollout.ring_buffer import RolloutRingBuffer from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1}) buf.append({"val": 1})
@@ -140,7 +142,7 @@ def test_ring_buffer_clear():
def test_ring_buffer_tensor_bytes(): def test_ring_buffer_tensor_bytes():
from lerobot.rollout.ring_buffer import RolloutRingBuffer from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes t = torch.zeros(100, dtype=torch.float32) # 400 bytes
@@ -154,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
def test_thread_safe_robot_delegates(): def test_thread_safe_robot_delegates():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3)) robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -174,7 +176,7 @@ def test_thread_safe_robot_delegates():
def test_thread_safe_robot_properties(): def test_thread_safe_robot_properties():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3)) robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -196,11 +198,8 @@ def test_thread_safe_robot_properties():
def test_create_strategy_dispatches(): def test_create_strategy_dispatches():
from lerobot.rollout.configs import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import create_strategy from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.rollout.strategies.dagger import DAggerStrategy
from lerobot.rollout.strategies.sentry import SentryStrategy
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy) assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy) assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
@@ -280,7 +279,7 @@ def test_safe_push_to_hub():
def test_dagger_full_transition_cycle(): def test_dagger_full_transition_cycle():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents() events = DAggerEvents()
assert events.phase == DAggerPhase.AUTONOMOUS assert events.phase == DAggerPhase.AUTONOMOUS
@@ -307,7 +306,7 @@ def test_dagger_full_transition_cycle():
def test_dagger_invalid_transition_ignored(): def test_dagger_invalid_transition_ignored():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents() events = DAggerEvents()
events.request_transition("correction") # Not valid from AUTONOMOUS events.request_transition("correction") # Not valid from AUTONOMOUS
@@ -316,7 +315,7 @@ def test_dagger_invalid_transition_ignored():
def test_dagger_events_reset(): def test_dagger_events_reset():
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents() events = DAggerEvents()
events.request_transition("pause_resume") events.request_transition("pause_resume")
@@ -333,7 +332,7 @@ def test_dagger_events_reset():
def test_rollout_context_fields(): def test_rollout_context_fields():
from lerobot.rollout.context import RolloutContext from lerobot.rollout import RolloutContext
field_names = {f.name for f in dataclasses.fields(RolloutContext)} field_names = {f.name for f in dataclasses.fields(RolloutContext)}
assert field_names == {"runtime", "hardware", "policy", "processors", "data"} assert field_names == {"runtime", "hardware", "policy", "processors", "data"}