mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
test: add dataset guard + fix imports
This commit is contained in:
@@ -227,10 +227,9 @@ See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
|
||||
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
||||
|
||||
```python
|
||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
||||
from lerobot.rollout.context import build_rollout_context
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies.base import BaseStrategy
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
|
||||
cfg = RolloutConfig(
|
||||
|
||||
@@ -24,10 +24,9 @@ recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
||||
from lerobot.rollout.context import build_rollout_context
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies.base import BaseStrategy
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -40,10 +40,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
||||
from lerobot.rollout.context import build_rollout_context
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies.base import BaseStrategy
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -38,10 +38,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
||||
from lerobot.rollout.context import build_rollout_context
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies.base import BaseStrategy
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -14,6 +14,10 @@
|
||||
|
||||
"""Policy deployment engine with pluggable rollout strategies."""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("datasets", extra="dataset")
|
||||
|
||||
from .configs import (
|
||||
BaseStrategyConfig,
|
||||
DAggerKeyboardConfig,
|
||||
@@ -25,7 +29,15 @@ from .configs import (
|
||||
RolloutStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
from .context import RolloutContext, build_rollout_context
|
||||
from .context import (
|
||||
DatasetContext,
|
||||
HardwareContext,
|
||||
PolicyContext,
|
||||
ProcessorContext,
|
||||
RolloutContext,
|
||||
RuntimeContext,
|
||||
build_rollout_context,
|
||||
)
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
InferenceEngineConfig,
|
||||
@@ -44,17 +56,22 @@ __all__ = [
|
||||
"DAggerKeyboardConfig",
|
||||
"DAggerPedalConfig",
|
||||
"DAggerStrategyConfig",
|
||||
"DatasetContext",
|
||||
"DatasetRecordConfig",
|
||||
"HardwareContext",
|
||||
"HighlightStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
"ProcessorContext",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"DatasetRecordConfig",
|
||||
"RolloutRingBuffer",
|
||||
"RolloutStrategy",
|
||||
"RolloutStrategyConfig",
|
||||
"RuntimeContext",
|
||||
"SentryStrategyConfig",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
|
||||
@@ -14,11 +14,21 @@
|
||||
|
||||
"""Rollout strategies — public API re-exports."""
|
||||
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
__all__ = [
|
||||
"BaseStrategy",
|
||||
"DAggerEvents",
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
"estimate_max_episode_seconds",
|
||||
"safe_push_to_hub",
|
||||
|
||||
@@ -25,7 +25,7 @@ from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rollout.configs import RolloutStrategyConfig
|
||||
from lerobot.rollout import RolloutStrategyConfig
|
||||
|
||||
|
||||
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
|
||||
+18
-19
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import smoke tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -54,7 +56,7 @@ def test_strategies_submodule_imports():
|
||||
|
||||
|
||||
def test_strategy_config_types():
|
||||
from lerobot.rollout.configs import (
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
@@ -68,14 +70,14 @@ def test_strategy_config_types():
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
from lerobot.rollout.configs import DAggerStrategyConfig
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
|
||||
DAggerStrategyConfig(input_device="joystick")
|
||||
|
||||
|
||||
def test_dagger_config_defaults():
|
||||
from lerobot.rollout.configs import DAggerStrategyConfig
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
cfg = DAggerStrategyConfig()
|
||||
assert cfg.num_episodes == 10
|
||||
@@ -95,7 +97,7 @@ def test_inference_config_types():
|
||||
|
||||
|
||||
def test_sentry_config_defaults():
|
||||
from lerobot.rollout.configs import SentryStrategyConfig
|
||||
from lerobot.rollout import SentryStrategyConfig
|
||||
|
||||
cfg = SentryStrategyConfig()
|
||||
assert cfg.upload_every_n_episodes == 5
|
||||
@@ -108,7 +110,7 @@ def test_sentry_config_defaults():
|
||||
|
||||
|
||||
def test_ring_buffer_append_and_eviction():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
||||
# max_frames = 5
|
||||
@@ -118,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
|
||||
|
||||
|
||||
def test_ring_buffer_drain():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
for i in range(3):
|
||||
@@ -130,7 +132,7 @@ def test_ring_buffer_drain():
|
||||
|
||||
|
||||
def test_ring_buffer_clear():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
buf.append({"val": 1})
|
||||
@@ -140,7 +142,7 @@ def test_ring_buffer_clear():
|
||||
|
||||
|
||||
def test_ring_buffer_tensor_bytes():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||
@@ -154,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
|
||||
|
||||
|
||||
def test_thread_safe_robot_delegates():
|
||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
@@ -174,7 +176,7 @@ def test_thread_safe_robot_delegates():
|
||||
|
||||
|
||||
def test_thread_safe_robot_properties():
|
||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
@@ -196,11 +198,8 @@ def test_thread_safe_robot_properties():
|
||||
|
||||
|
||||
def test_create_strategy_dispatches():
|
||||
from lerobot.rollout.configs import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
||||
from lerobot.rollout.strategies import create_strategy
|
||||
from lerobot.rollout.strategies.base import BaseStrategy
|
||||
from lerobot.rollout.strategies.dagger import DAggerStrategy
|
||||
from lerobot.rollout.strategies.sentry import SentryStrategy
|
||||
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
|
||||
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
@@ -280,7 +279,7 @@ def test_safe_push_to_hub():
|
||||
|
||||
|
||||
def test_dagger_full_transition_cycle():
|
||||
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
@@ -307,7 +306,7 @@ def test_dagger_full_transition_cycle():
|
||||
|
||||
|
||||
def test_dagger_invalid_transition_ignored():
|
||||
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("correction") # Not valid from AUTONOMOUS
|
||||
@@ -316,7 +315,7 @@ def test_dagger_invalid_transition_ignored():
|
||||
|
||||
|
||||
def test_dagger_events_reset():
|
||||
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("pause_resume")
|
||||
@@ -333,7 +332,7 @@ def test_dagger_events_reset():
|
||||
|
||||
|
||||
def test_rollout_context_fields():
|
||||
from lerobot.rollout.context import RolloutContext
|
||||
from lerobot.rollout import RolloutContext
|
||||
|
||||
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
|
||||
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
|
||||
|
||||
Reference in New Issue
Block a user