mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
chore(rollout): multiple minor improvements
This commit is contained in:
+18
-11
@@ -86,7 +86,7 @@ def test_dagger_config_defaults():
|
||||
|
||||
|
||||
def test_inference_config_types():
|
||||
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
|
||||
from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig
|
||||
|
||||
assert SyncInferenceConfig().type == "sync"
|
||||
|
||||
@@ -110,7 +110,7 @@ def test_sentry_config_defaults():
|
||||
|
||||
|
||||
def test_ring_buffer_append_and_eviction():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
||||
# max_frames = 5
|
||||
@@ -120,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
|
||||
|
||||
|
||||
def test_ring_buffer_drain():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
for i in range(3):
|
||||
@@ -132,7 +132,7 @@ def test_ring_buffer_drain():
|
||||
|
||||
|
||||
def test_ring_buffer_clear():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
buf.append({"val": 1})
|
||||
@@ -142,7 +142,7 @@ def test_ring_buffer_clear():
|
||||
|
||||
|
||||
def test_ring_buffer_tensor_bytes():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||
@@ -156,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
|
||||
|
||||
|
||||
def test_thread_safe_robot_delegates():
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
@@ -176,7 +176,7 @@ def test_thread_safe_robot_delegates():
|
||||
|
||||
|
||||
def test_thread_safe_robot_properties():
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
@@ -198,8 +198,15 @@ def test_thread_safe_robot_properties():
|
||||
|
||||
|
||||
def test_create_strategy_dispatches():
|
||||
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
|
||||
from lerobot.rollout import (
|
||||
BaseStrategy,
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategy,
|
||||
DAggerStrategyConfig,
|
||||
SentryStrategy,
|
||||
SentryStrategyConfig,
|
||||
create_strategy,
|
||||
)
|
||||
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
@@ -207,7 +214,7 @@ def test_create_strategy_dispatches():
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
from lerobot.rollout.strategies import create_strategy
|
||||
from lerobot.rollout import create_strategy
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.type = "bogus"
|
||||
@@ -221,7 +228,7 @@ def test_create_strategy_unknown_raises():
|
||||
|
||||
|
||||
def test_create_inference_engine_sync():
|
||||
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
|
||||
from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
|
||||
|
||||
engine = create_inference_engine(
|
||||
SyncInferenceConfig(),
|
||||
|
||||
Reference in New Issue
Block a user