feat(scripts): lerobot-rollout

This commit is contained in:
Steven Palma
2026-04-14 15:42:04 +02:00
committed by Steven Palma
parent 5c43fa1cce
commit bc06cb44ca
54 changed files with 5204 additions and 2816 deletions
@@ -17,9 +17,9 @@
import pytest
import torch
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.action_interpolator import ActionInterpolator
# ====================== Fixtures ======================
-82
View File
@@ -24,10 +24,6 @@ def lerobot_train(args):
return run_command(cmd="lerobot-train", module="lerobot_train", args=args)
def lerobot_record(args):
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
def resolve_model_id_for_peft_training(policy_type):
"""PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
if policy_type == "smolvla":
@@ -155,81 +151,3 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
f"--output_dir={output_dir}",
]
)
class DummyRobot:
name = "dummy"
cameras = []
action_features = {"foo": 1.0, "bar": 2.0}
observation_features = {"obs1": 1.0, "obs2": 2.0}
is_connected = True
def connect(self, *args):
pass
def disconnect(self):
pass
def dummy_make_robot_from_config(*args, **kwargs):
return DummyRobot()
@pytest.mark.parametrize("policy_type", ["smolvla"])
@skip_if_package_missing("peft")
def test_peft_record_loads_policy(policy_type, tmp_path):
"""Train a policy with PEFT and attempt to load it with `lerobot-record`."""
from peft import PeftModel
output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
lerobot_train(
[
f"--policy.path={model_id}",
"--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",
"--steps=1",
f"--output_dir={output_dir}",
]
)
policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
dataset_dir = tmp_path / "eval_pusht"
single_task = "move the table"
loaded_policy = None
def dummy_record_loop(*args, **kwargs):
nonlocal loaded_policy
if "dataset" not in kwargs:
return
dataset = kwargs["dataset"]
dataset.add_frame({"task": single_task})
loaded_policy = kwargs["policy"]
with (
patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config),
# disable record loop since we're only interested in successful loading of the policy.
patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop),
# disable speech output
patch("lerobot.utils.utils.say"),
):
lerobot_record(
[
f"--policy.path={policy_dir}",
"--robot.type=so101_follower",
"--robot.port=/dev/null",
"--dataset.repo_id=lerobot/eval_pusht",
f'--dataset.single_task="{single_task}"',
f"--dataset.root={dataset_dir}",
"--dataset.push_to_hub=false",
]
)
assert isinstance(loaded_policy, PeftModel)
+2 -1
View File
@@ -21,8 +21,9 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
from lerobot.scripts.lerobot_record import RecordConfig, record
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
from tests.fixtures.constants import DUMMY_REPO_ID
+338
View File
@@ -0,0 +1,338 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Minimal tests for the rollout module's public API."""
from __future__ import annotations
import dataclasses
from unittest.mock import MagicMock
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# ---------------------------------------------------------------------------
# Import smoke tests
# ---------------------------------------------------------------------------
def test_rollout_top_level_imports():
import lerobot.rollout
for name in lerobot.rollout.__all__:
assert hasattr(lerobot.rollout, name), f"Missing export: {name}"
def test_inference_submodule_imports():
import lerobot.rollout.inference
for name in lerobot.rollout.inference.__all__:
assert hasattr(lerobot.rollout.inference, name), f"Missing export: {name}"
def test_strategies_submodule_imports():
import lerobot.rollout.strategies
for name in lerobot.rollout.strategies.__all__:
assert hasattr(lerobot.rollout.strategies, name), f"Missing export: {name}"
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
def test_strategy_config_types():
from lerobot.rollout import (
BaseStrategyConfig,
DAggerStrategyConfig,
HighlightStrategyConfig,
SentryStrategyConfig,
)
assert BaseStrategyConfig().type == "base"
assert SentryStrategyConfig().type == "sentry"
assert HighlightStrategyConfig().type == "highlight"
assert DAggerStrategyConfig().type == "dagger"
def test_dagger_config_invalid_input_device():
from lerobot.rollout import DAggerStrategyConfig
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
DAggerStrategyConfig(input_device="joystick")
def test_dagger_config_defaults():
from lerobot.rollout import DAggerStrategyConfig
cfg = DAggerStrategyConfig()
assert cfg.num_episodes == 10
assert cfg.record_autonomous is False
assert cfg.input_device == "keyboard"
def test_inference_config_types():
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
assert SyncInferenceConfig().type == "sync"
rtc = RTCInferenceConfig()
assert rtc.type == "rtc"
assert rtc.queue_threshold == 30
assert rtc.rtc is not None
def test_sentry_config_defaults():
from lerobot.rollout import SentryStrategyConfig
cfg = SentryStrategyConfig()
assert cfg.upload_every_n_episodes == 5
assert cfg.target_video_file_size_mb is None
# ---------------------------------------------------------------------------
# RolloutRingBuffer
# ---------------------------------------------------------------------------
def test_ring_buffer_append_and_eviction():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5
for i in range(8):
buf.append({"val": i})
assert len(buf) == 5
def test_ring_buffer_drain():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3):
buf.append({"val": i})
frames = buf.drain()
assert len(frames) == 3
assert len(buf) == 0
assert buf.estimated_bytes == 0
def test_ring_buffer_clear():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1})
buf.clear()
assert len(buf) == 0
assert buf.estimated_bytes == 0
def test_ring_buffer_tensor_bytes():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
buf.append({"tensor": t})
assert buf.estimated_bytes >= 400
# ---------------------------------------------------------------------------
# ThreadSafeRobot
# ---------------------------------------------------------------------------
def test_thread_safe_robot_delegates():
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
robot.connect()
wrapper = ThreadSafeRobot(robot)
obs = wrapper.get_observation()
assert "motor_1.pos" in obs
assert "motor_2.pos" in obs
assert "motor_3.pos" in obs
action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
result = wrapper.send_action(action)
assert result == action
robot.disconnect()
def test_thread_safe_robot_properties():
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
robot.connect()
wrapper = ThreadSafeRobot(robot)
assert wrapper.name == "mock_robot"
assert "motor_1.pos" in wrapper.observation_features
assert "motor_1.pos" in wrapper.action_features
assert wrapper.is_connected is True
assert wrapper.inner is robot
robot.disconnect()
# ---------------------------------------------------------------------------
# Strategy factory
# ---------------------------------------------------------------------------
def test_create_strategy_dispatches():
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
def test_create_strategy_unknown_raises():
from lerobot.rollout.strategies import create_strategy
cfg = MagicMock()
cfg.type = "bogus"
with pytest.raises(ValueError, match="Unknown strategy type"):
create_strategy(cfg)
# ---------------------------------------------------------------------------
# Inference factory
# ---------------------------------------------------------------------------
def test_create_inference_engine_sync():
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
engine = create_inference_engine(
SyncInferenceConfig(),
policy=MagicMock(),
preprocessor=MagicMock(),
postprocessor=MagicMock(),
robot_wrapper=MagicMock(robot_type="mock"),
hw_features={},
dataset_features={},
ordered_action_keys=["k"],
task="test",
fps=30.0,
device="cpu",
)
assert isinstance(engine, SyncInferenceEngine)
# ---------------------------------------------------------------------------
# Pure functions
# ---------------------------------------------------------------------------
def test_estimate_max_episode_seconds_no_video():
from lerobot.rollout.strategies import estimate_max_episode_seconds
assert estimate_max_episode_seconds({}, fps=30.0) == 600.0
def test_estimate_max_episode_seconds_with_video():
from lerobot.rollout.strategies import estimate_max_episode_seconds
features = {"cam": {"dtype": "video", "shape": (3, 480, 640)}}
result = estimate_max_episode_seconds(features, fps=30.0)
assert result > 0
# With a real camera, duration should differ from the fallback
assert result != 600.0
def test_safe_push_to_hub():
from lerobot.rollout.strategies import safe_push_to_hub
ds = MagicMock()
ds.num_episodes = 0
assert safe_push_to_hub(ds) is False
ds.push_to_hub.assert_not_called()
ds.num_episodes = 5
assert safe_push_to_hub(ds, tags=["test"]) is True
ds.push_to_hub.assert_called_once_with(tags=["test"], private=False)
# ---------------------------------------------------------------------------
# DAgger state machine
# ---------------------------------------------------------------------------
def test_dagger_full_transition_cycle():
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
assert events.phase == DAggerPhase.AUTONOMOUS
# AUTONOMOUS -> PAUSED
events.request_transition("pause_resume")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED)
# PAUSED -> CORRECTING
events.request_transition("correction")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.CORRECTING)
# CORRECTING -> PAUSED
events.request_transition("correction")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.CORRECTING, DAggerPhase.PAUSED)
# PAUSED -> AUTONOMOUS
events.request_transition("pause_resume")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS)
def test_dagger_invalid_transition_ignored():
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("correction") # Not valid from AUTONOMOUS
assert events.consume_transition() is None
assert events.phase == DAggerPhase.AUTONOMOUS
def test_dagger_events_reset():
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("pause_resume")
events.consume_transition() # -> PAUSED
events.upload_requested.set()
events.reset()
assert events.phase == DAggerPhase.AUTONOMOUS
assert not events.upload_requested.is_set()
# ---------------------------------------------------------------------------
# Context dataclass
# ---------------------------------------------------------------------------
def test_rollout_context_fields():
from lerobot.rollout import RolloutContext
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
+1 -1
View File
@@ -24,7 +24,7 @@ import pytest
pytest.importorskip("grpc")
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test