chore(rollout): multiple minor improvements

This commit is contained in:
Steven Palma
2026-04-22 15:18:50 +02:00
parent bff2d50dc1
commit d70c3baf7c
9 changed files with 61 additions and 29 deletions
+18 -11
View File
@@ -86,7 +86,7 @@ def test_dagger_config_defaults():
def test_inference_config_types():
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig
assert SyncInferenceConfig().type == "sync"
@@ -110,7 +110,7 @@ def test_sentry_config_defaults():
def test_ring_buffer_append_and_eviction():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5
@@ -120,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
def test_ring_buffer_drain():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3):
@@ -132,7 +132,7 @@ def test_ring_buffer_drain():
def test_ring_buffer_clear():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1})
@@ -142,7 +142,7 @@ def test_ring_buffer_clear():
def test_ring_buffer_tensor_bytes():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
@@ -156,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
def test_thread_safe_robot_delegates():
from lerobot.rollout import ThreadSafeRobot
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -176,7 +176,7 @@ def test_thread_safe_robot_delegates():
def test_thread_safe_robot_properties():
from lerobot.rollout import ThreadSafeRobot
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -198,8 +198,15 @@ def test_thread_safe_robot_properties():
def test_create_strategy_dispatches():
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
from lerobot.rollout import (
BaseStrategy,
BaseStrategyConfig,
DAggerStrategy,
DAggerStrategyConfig,
SentryStrategy,
SentryStrategyConfig,
create_strategy,
)
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
@@ -207,7 +214,7 @@ def test_create_strategy_dispatches():
def test_create_strategy_unknown_raises():
from lerobot.rollout.strategies import create_strategy
from lerobot.rollout import create_strategy
cfg = MagicMock()
cfg.type = "bogus"
@@ -221,7 +228,7 @@ def test_create_strategy_unknown_raises():
def test_create_inference_engine_sync():
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
engine = create_inference_engine(
SyncInferenceConfig(),