mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
feat(scripts): lerobot-rollout
This commit is contained in:
committed by
Steven Palma
parent
5c43fa1cce
commit
bc06cb44ca
@@ -17,9 +17,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@@ -24,10 +24,6 @@ def lerobot_train(args):
|
||||
return run_command(cmd="lerobot-train", module="lerobot_train", args=args)
|
||||
|
||||
|
||||
def lerobot_record(args):
|
||||
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
||||
|
||||
|
||||
def resolve_model_id_for_peft_training(policy_type):
|
||||
"""PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
|
||||
if policy_type == "smolvla":
|
||||
@@ -155,81 +151,3 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
||||
f"--output_dir={output_dir}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DummyRobot:
|
||||
name = "dummy"
|
||||
cameras = []
|
||||
action_features = {"foo": 1.0, "bar": 2.0}
|
||||
observation_features = {"obs1": 1.0, "obs2": 2.0}
|
||||
is_connected = True
|
||||
|
||||
def connect(self, *args):
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
|
||||
def dummy_make_robot_from_config(*args, **kwargs):
|
||||
return DummyRobot()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||
@skip_if_package_missing("peft")
|
||||
def test_peft_record_loads_policy(policy_type, tmp_path):
|
||||
"""Train a policy with PEFT and attempt to load it with `lerobot-record`."""
|
||||
from peft import PeftModel
|
||||
|
||||
output_dir = tmp_path / f"output_{policy_type}"
|
||||
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||
|
||||
lerobot_train(
|
||||
[
|
||||
f"--policy.path={model_id}",
|
||||
"--policy.push_to_hub=false",
|
||||
"--policy.input_features=null",
|
||||
"--policy.output_features=null",
|
||||
"--peft.method=LORA",
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0, 1]",
|
||||
"--steps=1",
|
||||
f"--output_dir={output_dir}",
|
||||
]
|
||||
)
|
||||
|
||||
policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
|
||||
dataset_dir = tmp_path / "eval_pusht"
|
||||
single_task = "move the table"
|
||||
loaded_policy = None
|
||||
|
||||
def dummy_record_loop(*args, **kwargs):
|
||||
nonlocal loaded_policy
|
||||
|
||||
if "dataset" not in kwargs:
|
||||
return
|
||||
|
||||
dataset = kwargs["dataset"]
|
||||
dataset.add_frame({"task": single_task})
|
||||
loaded_policy = kwargs["policy"]
|
||||
|
||||
with (
|
||||
patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config),
|
||||
# disable record loop since we're only interested in successful loading of the policy.
|
||||
patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop),
|
||||
# disable speech output
|
||||
patch("lerobot.utils.utils.say"),
|
||||
):
|
||||
lerobot_record(
|
||||
[
|
||||
f"--policy.path={policy_dir}",
|
||||
"--robot.type=so101_follower",
|
||||
"--robot.port=/dev/null",
|
||||
"--dataset.repo_id=lerobot/eval_pusht",
|
||||
f'--dataset.single_task="{single_task}"',
|
||||
f"--dataset.root={dataset_dir}",
|
||||
"--dataset.push_to_hub=false",
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(loaded_policy, PeftModel)
|
||||
|
||||
@@ -21,8 +21,9 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
|
||||
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
|
||||
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
|
||||
from lerobot.scripts.lerobot_record import RecordConfig, record
|
||||
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
|
||||
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
@@ -0,0 +1,338 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Minimal tests for the rollout module's public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import smoke tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_rollout_top_level_imports():
|
||||
import lerobot.rollout
|
||||
|
||||
for name in lerobot.rollout.__all__:
|
||||
assert hasattr(lerobot.rollout, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
def test_inference_submodule_imports():
|
||||
import lerobot.rollout.inference
|
||||
|
||||
for name in lerobot.rollout.inference.__all__:
|
||||
assert hasattr(lerobot.rollout.inference, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
def test_strategies_submodule_imports():
|
||||
import lerobot.rollout.strategies
|
||||
|
||||
for name in lerobot.rollout.strategies.__all__:
|
||||
assert hasattr(lerobot.rollout.strategies, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
|
||||
assert BaseStrategyConfig().type == "base"
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
|
||||
DAggerStrategyConfig(input_device="joystick")
|
||||
|
||||
|
||||
def test_dagger_config_defaults():
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
cfg = DAggerStrategyConfig()
|
||||
assert cfg.num_episodes == 10
|
||||
assert cfg.record_autonomous is False
|
||||
assert cfg.input_device == "keyboard"
|
||||
|
||||
|
||||
def test_inference_config_types():
|
||||
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
|
||||
|
||||
assert SyncInferenceConfig().type == "sync"
|
||||
|
||||
rtc = RTCInferenceConfig()
|
||||
assert rtc.type == "rtc"
|
||||
assert rtc.queue_threshold == 30
|
||||
assert rtc.rtc is not None
|
||||
|
||||
|
||||
def test_sentry_config_defaults():
|
||||
from lerobot.rollout import SentryStrategyConfig
|
||||
|
||||
cfg = SentryStrategyConfig()
|
||||
assert cfg.upload_every_n_episodes == 5
|
||||
assert cfg.target_video_file_size_mb is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RolloutRingBuffer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ring_buffer_append_and_eviction():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
||||
# max_frames = 5
|
||||
for i in range(8):
|
||||
buf.append({"val": i})
|
||||
assert len(buf) == 5
|
||||
|
||||
|
||||
def test_ring_buffer_drain():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
for i in range(3):
|
||||
buf.append({"val": i})
|
||||
frames = buf.drain()
|
||||
assert len(frames) == 3
|
||||
assert len(buf) == 0
|
||||
assert buf.estimated_bytes == 0
|
||||
|
||||
|
||||
def test_ring_buffer_clear():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
buf.append({"val": 1})
|
||||
buf.clear()
|
||||
assert len(buf) == 0
|
||||
assert buf.estimated_bytes == 0
|
||||
|
||||
|
||||
def test_ring_buffer_tensor_bytes():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||
buf.append({"tensor": t})
|
||||
assert buf.estimated_bytes >= 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThreadSafeRobot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_thread_safe_robot_delegates():
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
robot.connect()
|
||||
wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
obs = wrapper.get_observation()
|
||||
assert "motor_1.pos" in obs
|
||||
assert "motor_2.pos" in obs
|
||||
assert "motor_3.pos" in obs
|
||||
|
||||
action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
|
||||
result = wrapper.send_action(action)
|
||||
assert result == action
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_thread_safe_robot_properties():
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
robot.connect()
|
||||
wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
assert wrapper.name == "mock_robot"
|
||||
assert "motor_1.pos" in wrapper.observation_features
|
||||
assert "motor_1.pos" in wrapper.action_features
|
||||
assert wrapper.is_connected is True
|
||||
assert wrapper.inner is robot
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_strategy_dispatches():
|
||||
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
|
||||
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
from lerobot.rollout.strategies import create_strategy
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.type = "bogus"
|
||||
with pytest.raises(ValueError, match="Unknown strategy type"):
|
||||
create_strategy(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_inference_engine_sync():
|
||||
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
|
||||
|
||||
engine = create_inference_engine(
|
||||
SyncInferenceConfig(),
|
||||
policy=MagicMock(),
|
||||
preprocessor=MagicMock(),
|
||||
postprocessor=MagicMock(),
|
||||
robot_wrapper=MagicMock(robot_type="mock"),
|
||||
hw_features={},
|
||||
dataset_features={},
|
||||
ordered_action_keys=["k"],
|
||||
task="test",
|
||||
fps=30.0,
|
||||
device="cpu",
|
||||
)
|
||||
assert isinstance(engine, SyncInferenceEngine)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_estimate_max_episode_seconds_no_video():
|
||||
from lerobot.rollout.strategies import estimate_max_episode_seconds
|
||||
|
||||
assert estimate_max_episode_seconds({}, fps=30.0) == 600.0
|
||||
|
||||
|
||||
def test_estimate_max_episode_seconds_with_video():
|
||||
from lerobot.rollout.strategies import estimate_max_episode_seconds
|
||||
|
||||
features = {"cam": {"dtype": "video", "shape": (3, 480, 640)}}
|
||||
result = estimate_max_episode_seconds(features, fps=30.0)
|
||||
assert result > 0
|
||||
# With a real camera, duration should differ from the fallback
|
||||
assert result != 600.0
|
||||
|
||||
|
||||
def test_safe_push_to_hub():
|
||||
from lerobot.rollout.strategies import safe_push_to_hub
|
||||
|
||||
ds = MagicMock()
|
||||
ds.num_episodes = 0
|
||||
assert safe_push_to_hub(ds) is False
|
||||
ds.push_to_hub.assert_not_called()
|
||||
|
||||
ds.num_episodes = 5
|
||||
assert safe_push_to_hub(ds, tags=["test"]) is True
|
||||
ds.push_to_hub.assert_called_once_with(tags=["test"], private=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DAgger state machine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_dagger_full_transition_cycle():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
|
||||
# AUTONOMOUS -> PAUSED
|
||||
events.request_transition("pause_resume")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED)
|
||||
|
||||
# PAUSED -> CORRECTING
|
||||
events.request_transition("correction")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.CORRECTING)
|
||||
|
||||
# CORRECTING -> PAUSED
|
||||
events.request_transition("correction")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.CORRECTING, DAggerPhase.PAUSED)
|
||||
|
||||
# PAUSED -> AUTONOMOUS
|
||||
events.request_transition("pause_resume")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS)
|
||||
|
||||
|
||||
def test_dagger_invalid_transition_ignored():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("correction") # Not valid from AUTONOMOUS
|
||||
assert events.consume_transition() is None
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
|
||||
|
||||
def test_dagger_events_reset():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("pause_resume")
|
||||
events.consume_transition() # -> PAUSED
|
||||
events.upload_requested.set()
|
||||
events.reset()
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
assert not events.upload_requested.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_rollout_context_fields():
|
||||
from lerobot.rollout import RolloutContext
|
||||
|
||||
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
|
||||
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
|
||||
@@ -24,7 +24,7 @@ import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
||||
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||
|
||||
Reference in New Issue
Block a user