mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
test: add dataset guard + fix imports
This commit is contained in:
@@ -227,10 +227,9 @@ See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
|
|||||||
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
from lerobot.rollout.context import build_rollout_context
|
|
||||||
from lerobot.rollout.inference import SyncInferenceConfig
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
from lerobot.rollout.strategies.base import BaseStrategy
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
from lerobot.utils.process import ProcessSignalHandler
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
|
||||||
cfg = RolloutConfig(
|
cfg = RolloutConfig(
|
||||||
|
|||||||
@@ -24,10 +24,9 @@ recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
|
|||||||
|
|
||||||
from lerobot.configs import PreTrainedConfig
|
from lerobot.configs import PreTrainedConfig
|
||||||
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
||||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
from lerobot.rollout.context import build_rollout_context
|
|
||||||
from lerobot.rollout.inference import SyncInferenceConfig
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
from lerobot.rollout.strategies.base import BaseStrategy
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
from lerobot.utils.process import ProcessSignalHandler
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
|||||||
@@ -40,10 +40,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
|||||||
ForwardKinematicsJointsToEE,
|
ForwardKinematicsJointsToEE,
|
||||||
InverseKinematicsEEToJoints,
|
InverseKinematicsEEToJoints,
|
||||||
)
|
)
|
||||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
from lerobot.rollout.context import build_rollout_context
|
|
||||||
from lerobot.rollout.inference import SyncInferenceConfig
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
from lerobot.rollout.strategies.base import BaseStrategy
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.process import ProcessSignalHandler
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|||||||
@@ -38,10 +38,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
|||||||
ForwardKinematicsJointsToEE,
|
ForwardKinematicsJointsToEE,
|
||||||
InverseKinematicsEEToJoints,
|
InverseKinematicsEEToJoints,
|
||||||
)
|
)
|
||||||
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
from lerobot.rollout.context import build_rollout_context
|
|
||||||
from lerobot.rollout.inference import SyncInferenceConfig
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
from lerobot.rollout.strategies.base import BaseStrategy
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.process import ProcessSignalHandler
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|||||||
@@ -14,6 +14,10 @@
|
|||||||
|
|
||||||
"""Policy deployment engine with pluggable rollout strategies."""
|
"""Policy deployment engine with pluggable rollout strategies."""
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import require_package
|
||||||
|
|
||||||
|
require_package("datasets", extra="dataset")
|
||||||
|
|
||||||
from .configs import (
|
from .configs import (
|
||||||
BaseStrategyConfig,
|
BaseStrategyConfig,
|
||||||
DAggerKeyboardConfig,
|
DAggerKeyboardConfig,
|
||||||
@@ -25,7 +29,15 @@ from .configs import (
|
|||||||
RolloutStrategyConfig,
|
RolloutStrategyConfig,
|
||||||
SentryStrategyConfig,
|
SentryStrategyConfig,
|
||||||
)
|
)
|
||||||
from .context import RolloutContext, build_rollout_context
|
from .context import (
|
||||||
|
DatasetContext,
|
||||||
|
HardwareContext,
|
||||||
|
PolicyContext,
|
||||||
|
ProcessorContext,
|
||||||
|
RolloutContext,
|
||||||
|
RuntimeContext,
|
||||||
|
build_rollout_context,
|
||||||
|
)
|
||||||
from .inference import (
|
from .inference import (
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
InferenceEngineConfig,
|
InferenceEngineConfig,
|
||||||
@@ -44,17 +56,22 @@ __all__ = [
|
|||||||
"DAggerKeyboardConfig",
|
"DAggerKeyboardConfig",
|
||||||
"DAggerPedalConfig",
|
"DAggerPedalConfig",
|
||||||
"DAggerStrategyConfig",
|
"DAggerStrategyConfig",
|
||||||
|
"DatasetContext",
|
||||||
|
"DatasetRecordConfig",
|
||||||
|
"HardwareContext",
|
||||||
"HighlightStrategyConfig",
|
"HighlightStrategyConfig",
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"InferenceEngineConfig",
|
"InferenceEngineConfig",
|
||||||
|
"PolicyContext",
|
||||||
|
"ProcessorContext",
|
||||||
"RTCInferenceConfig",
|
"RTCInferenceConfig",
|
||||||
"RTCInferenceEngine",
|
"RTCInferenceEngine",
|
||||||
"RolloutConfig",
|
"RolloutConfig",
|
||||||
"RolloutContext",
|
"RolloutContext",
|
||||||
"DatasetRecordConfig",
|
|
||||||
"RolloutRingBuffer",
|
"RolloutRingBuffer",
|
||||||
"RolloutStrategy",
|
"RolloutStrategy",
|
||||||
"RolloutStrategyConfig",
|
"RolloutStrategyConfig",
|
||||||
|
"RuntimeContext",
|
||||||
"SentryStrategyConfig",
|
"SentryStrategyConfig",
|
||||||
"SyncInferenceConfig",
|
"SyncInferenceConfig",
|
||||||
"SyncInferenceEngine",
|
"SyncInferenceEngine",
|
||||||
|
|||||||
@@ -14,11 +14,21 @@
|
|||||||
|
|
||||||
"""Rollout strategies — public API re-exports."""
|
"""Rollout strategies — public API re-exports."""
|
||||||
|
|
||||||
|
from .base import BaseStrategy
|
||||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||||
|
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||||
from .factory import create_strategy
|
from .factory import create_strategy
|
||||||
|
from .highlight import HighlightStrategy
|
||||||
|
from .sentry import SentryStrategy
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"BaseStrategy",
|
||||||
|
"DAggerEvents",
|
||||||
|
"DAggerPhase",
|
||||||
|
"DAggerStrategy",
|
||||||
|
"HighlightStrategy",
|
||||||
"RolloutStrategy",
|
"RolloutStrategy",
|
||||||
|
"SentryStrategy",
|
||||||
"create_strategy",
|
"create_strategy",
|
||||||
"estimate_max_episode_seconds",
|
"estimate_max_episode_seconds",
|
||||||
"safe_push_to_hub",
|
"safe_push_to_hub",
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from .highlight import HighlightStrategy
|
|||||||
from .sentry import SentryStrategy
|
from .sentry import SentryStrategy
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from lerobot.rollout.configs import RolloutStrategyConfig
|
from lerobot.rollout import RolloutStrategyConfig
|
||||||
|
|
||||||
|
|
||||||
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||||
|
|||||||
+18
-19
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Import smoke tests
|
# Import smoke tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -54,7 +56,7 @@ def test_strategies_submodule_imports():
|
|||||||
|
|
||||||
|
|
||||||
def test_strategy_config_types():
|
def test_strategy_config_types():
|
||||||
from lerobot.rollout.configs import (
|
from lerobot.rollout import (
|
||||||
BaseStrategyConfig,
|
BaseStrategyConfig,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
HighlightStrategyConfig,
|
HighlightStrategyConfig,
|
||||||
@@ -68,14 +70,14 @@ def test_strategy_config_types():
|
|||||||
|
|
||||||
|
|
||||||
def test_dagger_config_invalid_input_device():
|
def test_dagger_config_invalid_input_device():
|
||||||
from lerobot.rollout.configs import DAggerStrategyConfig
|
from lerobot.rollout import DAggerStrategyConfig
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
|
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
|
||||||
DAggerStrategyConfig(input_device="joystick")
|
DAggerStrategyConfig(input_device="joystick")
|
||||||
|
|
||||||
|
|
||||||
def test_dagger_config_defaults():
|
def test_dagger_config_defaults():
|
||||||
from lerobot.rollout.configs import DAggerStrategyConfig
|
from lerobot.rollout import DAggerStrategyConfig
|
||||||
|
|
||||||
cfg = DAggerStrategyConfig()
|
cfg = DAggerStrategyConfig()
|
||||||
assert cfg.num_episodes == 10
|
assert cfg.num_episodes == 10
|
||||||
@@ -95,7 +97,7 @@ def test_inference_config_types():
|
|||||||
|
|
||||||
|
|
||||||
def test_sentry_config_defaults():
|
def test_sentry_config_defaults():
|
||||||
from lerobot.rollout.configs import SentryStrategyConfig
|
from lerobot.rollout import SentryStrategyConfig
|
||||||
|
|
||||||
cfg = SentryStrategyConfig()
|
cfg = SentryStrategyConfig()
|
||||||
assert cfg.upload_every_n_episodes == 5
|
assert cfg.upload_every_n_episodes == 5
|
||||||
@@ -108,7 +110,7 @@ def test_sentry_config_defaults():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_append_and_eviction():
|
def test_ring_buffer_append_and_eviction():
|
||||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
from lerobot.rollout import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
||||||
# max_frames = 5
|
# max_frames = 5
|
||||||
@@ -118,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_drain():
|
def test_ring_buffer_drain():
|
||||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
from lerobot.rollout import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
@@ -130,7 +132,7 @@ def test_ring_buffer_drain():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_clear():
|
def test_ring_buffer_clear():
|
||||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
from lerobot.rollout import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||||
buf.append({"val": 1})
|
buf.append({"val": 1})
|
||||||
@@ -140,7 +142,7 @@ def test_ring_buffer_clear():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_tensor_bytes():
|
def test_ring_buffer_tensor_bytes():
|
||||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
from lerobot.rollout import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||||
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||||
@@ -154,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
|
|||||||
|
|
||||||
|
|
||||||
def test_thread_safe_robot_delegates():
|
def test_thread_safe_robot_delegates():
|
||||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
from lerobot.rollout import ThreadSafeRobot
|
||||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||||
|
|
||||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||||
@@ -174,7 +176,7 @@ def test_thread_safe_robot_delegates():
|
|||||||
|
|
||||||
|
|
||||||
def test_thread_safe_robot_properties():
|
def test_thread_safe_robot_properties():
|
||||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
from lerobot.rollout import ThreadSafeRobot
|
||||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||||
|
|
||||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||||
@@ -196,11 +198,8 @@ def test_thread_safe_robot_properties():
|
|||||||
|
|
||||||
|
|
||||||
def test_create_strategy_dispatches():
|
def test_create_strategy_dispatches():
|
||||||
from lerobot.rollout.configs import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
||||||
from lerobot.rollout.strategies import create_strategy
|
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
|
||||||
from lerobot.rollout.strategies.base import BaseStrategy
|
|
||||||
from lerobot.rollout.strategies.dagger import DAggerStrategy
|
|
||||||
from lerobot.rollout.strategies.sentry import SentryStrategy
|
|
||||||
|
|
||||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||||
@@ -280,7 +279,7 @@ def test_safe_push_to_hub():
|
|||||||
|
|
||||||
|
|
||||||
def test_dagger_full_transition_cycle():
|
def test_dagger_full_transition_cycle():
|
||||||
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
|
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||||
|
|
||||||
events = DAggerEvents()
|
events = DAggerEvents()
|
||||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||||
@@ -307,7 +306,7 @@ def test_dagger_full_transition_cycle():
|
|||||||
|
|
||||||
|
|
||||||
def test_dagger_invalid_transition_ignored():
|
def test_dagger_invalid_transition_ignored():
|
||||||
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
|
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||||
|
|
||||||
events = DAggerEvents()
|
events = DAggerEvents()
|
||||||
events.request_transition("correction") # Not valid from AUTONOMOUS
|
events.request_transition("correction") # Not valid from AUTONOMOUS
|
||||||
@@ -316,7 +315,7 @@ def test_dagger_invalid_transition_ignored():
|
|||||||
|
|
||||||
|
|
||||||
def test_dagger_events_reset():
|
def test_dagger_events_reset():
|
||||||
from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
|
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||||
|
|
||||||
events = DAggerEvents()
|
events = DAggerEvents()
|
||||||
events.request_transition("pause_resume")
|
events.request_transition("pause_resume")
|
||||||
@@ -333,7 +332,7 @@ def test_dagger_events_reset():
|
|||||||
|
|
||||||
|
|
||||||
def test_rollout_context_fields():
|
def test_rollout_context_fields():
|
||||||
from lerobot.rollout.context import RolloutContext
|
from lerobot.rollout import RolloutContext
|
||||||
|
|
||||||
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
|
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
|
||||||
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
|
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
|
||||||
|
|||||||
Reference in New Issue
Block a user