Files
lerobot/tests/test_rollout.py
Steven Palma ca87ccd941 feat(rollout): decouple policy deployment from data recording with new lerobot-rollout CLI (#3413)
* feat(scripts): lerobot-rollout

* fix(rollout): require dataset in dagger + use duration too

* fix(docs): dagger num_episodes

* test(rollout): fix expectations

* fix(rollout): features check

* fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config

* docs(rollout): edit rename_map instructions

* chore(rollout): multiple minor improvements

* chore(rollout): address comments + minor improvements

* fix(rollout): enable default

* fix(tests): default value RTCConfig

* fix(rollout): robot_observation_processor and notify_observation at policy frequency instead of interpolator rate

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): prevent relative actions with sync inference engine

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): rtc reanchor to non normalized state

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): fixing the episode length to use hwc (#3469)

also reducing default length to 5 minutes

* feat(rollout): go back to initial position is now a config

* fix(rollout): properly propagating video_files_size_in_mb to lerobot_dataset (#3470)

* chore(rollout): note about dagger correction stage

* chore(docs): update comments and docstring

* fix(test): move rtc relative out of rollout module

* fix(rollout): address the review comments

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
2026-04-28 00:57:35 +02:00

346 lines
11 KiB
Python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Minimal tests for the rollout module's public API."""
from __future__ import annotations
import dataclasses
from unittest.mock import MagicMock
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# ---------------------------------------------------------------------------
# Import smoke tests
# ---------------------------------------------------------------------------
def test_rollout_top_level_imports():
    """Every name advertised in ``lerobot.rollout.__all__`` resolves on the package."""
    import lerobot.rollout as rollout_pkg

    exported_names = rollout_pkg.__all__
    for name in exported_names:
        assert hasattr(rollout_pkg, name), f"Missing export: {name}"
def test_inference_submodule_imports():
    """Every name advertised in ``lerobot.rollout.inference.__all__`` resolves on the module."""
    import lerobot.rollout.inference as inference_mod

    exported_names = inference_mod.__all__
    for name in exported_names:
        assert hasattr(inference_mod, name), f"Missing export: {name}"
def test_strategies_submodule_imports():
    """Every name advertised in ``lerobot.rollout.strategies.__all__`` resolves on the module."""
    import lerobot.rollout.strategies as strategies_mod

    exported_names = strategies_mod.__all__
    for name in exported_names:
        assert hasattr(strategies_mod, name), f"Missing export: {name}"
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
def test_strategy_config_types():
    """Each default-constructed strategy config reports its registered ``type`` tag."""
    from lerobot.rollout import (
        BaseStrategyConfig,
        DAggerStrategyConfig,
        HighlightStrategyConfig,
        SentryStrategyConfig,
    )

    expected_tags = (
        (BaseStrategyConfig, "base"),
        (SentryStrategyConfig, "sentry"),
        (HighlightStrategyConfig, "highlight"),
        (DAggerStrategyConfig, "dagger"),
    )
    for config_cls, tag in expected_tags:
        assert config_cls().type == tag
def test_dagger_config_invalid_input_device():
    """DAggerStrategyConfig rejects an input device other than keyboard/pedal."""
    from lerobot.rollout import DAggerStrategyConfig

    unsupported_device = "joystick"
    with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
        DAggerStrategyConfig(input_device=unsupported_device)
def test_dagger_config_defaults():
    """Pin the default field values of DAggerStrategyConfig."""
    from lerobot.rollout import DAggerStrategyConfig

    config = DAggerStrategyConfig()

    # Identity checks (``is``) on purpose: 0 or "" must not pass for None/False.
    assert config.num_episodes is None
    assert config.record_autonomous is False
    assert config.input_device == "keyboard"
def test_inference_config_types():
    """Inference configs report their type tag; the RTC config has its expected defaults."""
    from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig

    sync_config = SyncInferenceConfig()
    assert sync_config.type == "sync"

    rtc_config = RTCInferenceConfig()
    assert rtc_config.type == "rtc"
    assert rtc_config.queue_threshold == 30
    # The nested RTC settings object must be populated by default.
    assert rtc_config.rtc is not None
def test_sentry_config_defaults():
    """Pin the default upload cadence and video size limit of SentryStrategyConfig."""
    from lerobot.rollout import SentryStrategyConfig

    sentry = SentryStrategyConfig()
    assert sentry.upload_every_n_episodes == 5
    assert sentry.target_video_file_size_mb is None
# ---------------------------------------------------------------------------
# RolloutRingBuffer
# ---------------------------------------------------------------------------
def test_ring_buffer_append_and_eviction():
    """Appending more frames than fit keeps the buffer length bounded."""
    from lerobot.rollout.ring_buffer import RolloutRingBuffer

    # 0.5 s of history at 10 fps -> 5 frames expected to be retained.
    expected_capacity = 5
    buffer = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)

    # Overfill by 3 frames so eviction has to happen.
    for frame_idx in range(expected_capacity + 3):
        buffer.append({"val": frame_idx})

    assert len(buffer) == expected_capacity
def test_ring_buffer_drain():
    """drain() returns all buffered frames and leaves the buffer empty."""
    from lerobot.rollout.ring_buffer import RolloutRingBuffer

    num_frames = 3
    buffer = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
    for frame_idx in range(num_frames):
        buffer.append({"val": frame_idx})

    drained = buffer.drain()
    assert len(drained) == num_frames

    # After draining, both the frame count and the byte accounting are reset.
    assert len(buffer) == 0
    assert buffer.estimated_bytes == 0
def test_ring_buffer_clear():
    """clear() discards buffered frames and resets the byte estimate."""
    from lerobot.rollout.ring_buffer import RolloutRingBuffer

    buffer = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
    buffer.append({"val": 1})

    buffer.clear()

    assert len(buffer) == 0
    assert buffer.estimated_bytes == 0
def test_ring_buffer_tensor_bytes():
    """The byte estimate accounts for at least the raw payload of a stored tensor."""
    from lerobot.rollout.ring_buffer import RolloutRingBuffer

    buffer = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)

    tensor = torch.zeros(100, dtype=torch.float32)
    payload_bytes = tensor.element_size() * tensor.numel()  # 4 B * 100 = 400 B

    buffer.append({"tensor": tensor})
    assert buffer.estimated_bytes >= payload_bytes
# ---------------------------------------------------------------------------
# ThreadSafeRobot
# ---------------------------------------------------------------------------
def test_thread_safe_robot_delegates():
    """ThreadSafeRobot forwards get_observation() and send_action() to the wrapped robot.

    The wrapped MockRobot is disconnected in ``finally`` so that a failing
    assertion does not leave it connected.
    """
    from lerobot.rollout.robot_wrapper import ThreadSafeRobot
    from tests.mocks.mock_robot import MockRobot, MockRobotConfig

    robot = MockRobot(MockRobotConfig(n_motors=3))
    robot.connect()
    try:
        wrapper = ThreadSafeRobot(robot)

        # One position key per configured motor is expected in the observation.
        obs = wrapper.get_observation()
        for key in ("motor_1.pos", "motor_2.pos", "motor_3.pos"):
            assert key in obs

        # send_action() through the wrapper returns the action dict it was given.
        action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
        assert wrapper.send_action(action) == action
    finally:
        robot.disconnect()
def test_thread_safe_robot_properties():
    """ThreadSafeRobot exposes the wrapped robot's name, features, connection state and itself.

    The wrapped MockRobot is disconnected in ``finally`` so that a failing
    assertion does not leave it connected.
    """
    from lerobot.rollout.robot_wrapper import ThreadSafeRobot
    from tests.mocks.mock_robot import MockRobot, MockRobotConfig

    robot = MockRobot(MockRobotConfig(n_motors=3))
    robot.connect()
    try:
        wrapper = ThreadSafeRobot(robot)

        assert wrapper.name == "mock_robot"
        assert "motor_1.pos" in wrapper.observation_features
        assert "motor_1.pos" in wrapper.action_features
        assert wrapper.is_connected is True
        # ``inner`` must be the very same object, not a copy.
        assert wrapper.inner is robot
    finally:
        robot.disconnect()
# ---------------------------------------------------------------------------
# Strategy factory
# ---------------------------------------------------------------------------
def test_create_strategy_dispatches():
    """create_strategy builds the strategy class that matches each config type."""
    from lerobot.rollout import (
        BaseStrategy,
        BaseStrategyConfig,
        DAggerStrategy,
        DAggerStrategyConfig,
        SentryStrategy,
        SentryStrategyConfig,
        create_strategy,
    )

    config_to_strategy = (
        (BaseStrategyConfig, BaseStrategy),
        (SentryStrategyConfig, SentryStrategy),
        (DAggerStrategyConfig, DAggerStrategy),
    )
    for config_cls, strategy_cls in config_to_strategy:
        assert isinstance(create_strategy(config_cls()), strategy_cls)
def test_create_strategy_unknown_raises():
    """create_strategy raises ValueError for an unregistered strategy type."""
    from lerobot.rollout import create_strategy

    # A mock stands in for a config object whose ``type`` is not registered.
    bogus_config = MagicMock(type="bogus")

    with pytest.raises(ValueError, match="Unknown strategy type"):
        create_strategy(bogus_config)
# ---------------------------------------------------------------------------
# Inference factory
# ---------------------------------------------------------------------------
def test_create_inference_engine_sync():
    """create_inference_engine returns a SyncInferenceEngine for a SyncInferenceConfig."""
    from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine

    # Collaborators are mocks; the test only checks the type of engine returned.
    engine_kwargs = {
        "policy": MagicMock(),
        "preprocessor": MagicMock(),
        "postprocessor": MagicMock(),
        "robot_wrapper": MagicMock(robot_type="mock"),
        "hw_features": {},
        "dataset_features": {},
        "ordered_action_keys": ["k"],
        "task": "test",
        "fps": 30.0,
        "device": "cpu",
    }

    engine = create_inference_engine(SyncInferenceConfig(), **engine_kwargs)
    assert isinstance(engine, SyncInferenceEngine)
# ---------------------------------------------------------------------------
# Pure functions
# ---------------------------------------------------------------------------
def test_estimate_max_episode_seconds_no_video():
    """Without any video features the estimator falls back to 300 s (5 minutes)."""
    from lerobot.rollout.strategies import estimate_max_episode_seconds

    fallback_seconds = 300.0
    assert estimate_max_episode_seconds({}, fps=30.0) == fallback_seconds
def test_estimate_max_episode_seconds_with_video():
    """A video feature yields a positive estimate different from the 300 s fallback."""
    from lerobot.rollout.strategies import estimate_max_episode_seconds

    fallback_seconds = 300.0
    # One camera stream, shape given as (height, width, channels).
    camera_features = {"cam": {"dtype": "video", "shape": (480, 640, 3)}}

    estimate = estimate_max_episode_seconds(camera_features, fps=30.0)

    assert estimate > 0
    # A real camera should drive the estimate, not the fallback value.
    assert estimate != fallback_seconds
def test_safe_push_to_hub():
    """safe_push_to_hub skips datasets with no episodes and pushes non-empty ones."""
    from lerobot.rollout.strategies import safe_push_to_hub

    dataset = MagicMock()

    # No episodes recorded: returns False and never calls push_to_hub.
    dataset.num_episodes = 0
    assert safe_push_to_hub(dataset) is False
    dataset.push_to_hub.assert_not_called()

    # Episodes present: returns True and pushes once with the given tags and private=False.
    dataset.num_episodes = 5
    assert safe_push_to_hub(dataset, tags=["test"]) is True
    dataset.push_to_hub.assert_called_once_with(tags=["test"], private=False)
# ---------------------------------------------------------------------------
# DAgger state machine
# ---------------------------------------------------------------------------
def test_dagger_full_transition_cycle():
    """Drive DAggerEvents through AUTONOMOUS -> PAUSED -> CORRECTING -> PAUSED -> AUTONOMOUS."""
    from lerobot.rollout.strategies import DAggerEvents, DAggerPhase

    events = DAggerEvents()
    assert events.phase == DAggerPhase.AUTONOMOUS

    # Each step: (requested event, expected phase before, expected phase after).
    cycle = (
        ("pause_resume", DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED),
        ("correction", DAggerPhase.PAUSED, DAggerPhase.CORRECTING),
        ("correction", DAggerPhase.CORRECTING, DAggerPhase.PAUSED),
        ("pause_resume", DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS),
    )
    for event_name, phase_before, phase_after in cycle:
        events.request_transition(event_name)
        old, new = events.consume_transition()
        assert (old, new) == (phase_before, phase_after)
def test_dagger_invalid_transition_ignored():
    """A 'correction' request while AUTONOMOUS produces no transition and keeps the phase."""
    from lerobot.rollout.strategies import DAggerEvents, DAggerPhase

    dagger_events = DAggerEvents()

    # 'correction' is not a valid event from the initial AUTONOMOUS phase.
    dagger_events.request_transition("correction")

    assert dagger_events.consume_transition() is None
    assert dagger_events.phase == DAggerPhase.AUTONOMOUS
def test_dagger_events_reset():
    """reset() restores the AUTONOMOUS phase and clears a pending upload request."""
    from lerobot.rollout.strategies import DAggerEvents, DAggerPhase

    dagger_events = DAggerEvents()

    # Move away from the initial state: switch to PAUSED and flag an upload.
    dagger_events.request_transition("pause_resume")
    dagger_events.consume_transition()
    dagger_events.upload_requested.set()

    dagger_events.reset()

    assert dagger_events.phase == DAggerPhase.AUTONOMOUS
    assert not dagger_events.upload_requested.is_set()
# ---------------------------------------------------------------------------
# Context dataclass
# ---------------------------------------------------------------------------
def test_rollout_context_fields():
    """RolloutContext is a dataclass with exactly these five top-level fields."""
    from lerobot.rollout import RolloutContext

    expected = {"runtime", "hardware", "policy", "processors", "data"}
    actual = {field.name for field in dataclasses.fields(RolloutContext)}
    assert actual == expected