mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
feat(rollout): decouple policy deployment from data recording with new lerobot-rollout CLI (#3413)
* feat(scripts): lerobot-rollout * fix(rollout) require dataset in dagger + use duration too * fix(docs): dagger num_episodes * test(rollout): fix expectations * fix(rollout): features check * fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config * docs(rollout): edit rename_map instructions * chore(rollout): multiple minor improvements * chore(rollout): address coments + minor improvements * fix(rollout): enable default * fix(tests): default value RTCConfig * fix(rollout): robot_observation_processor and notify_observation at policy frequency instead of interpolator rate Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): prevent relativeactions with sync inference engine Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): rtc reanchor to non normalized state Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): fixing the episode length to use hwc (#3469) also reducing default length to 5 minutes * feat(rollout): go back to initial position is now a config * fix(rollout): properly propagating video_files_size_in_mb to lerobot_dataset (#3470) * chore(rollout): note about dagger correction stage * chore(docs): update comments and docstring * fix(test): move rtc relative out of rollout module * fix(rollout): address the review comments --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
This commit is contained in:
@@ -416,6 +416,18 @@ def test_create_initial_counts_zero(tmp_path):
|
||||
assert dataset.num_frames == 0
|
||||
|
||||
|
||||
def test_create_propagates_video_files_size_in_mb(tmp_path):
|
||||
"""video_files_size_in_mb passed to create() is reflected in the dataset metadata."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=tmp_path / "ds",
|
||||
video_files_size_in_mb=42.0,
|
||||
)
|
||||
assert dataset.meta.video_files_size_in_mb == 42.0
|
||||
|
||||
|
||||
def test_add_frame_works_in_write_mode(tmp_path):
|
||||
"""add_frame() succeeds on a dataset created via create()."""
|
||||
dataset = LeRobotDataset.create(
|
||||
|
||||
@@ -17,9 +17,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def test_rtc_config_default_initialization():
|
||||
"""Test RTCConfig initializes with default values."""
|
||||
config = RTCConfig()
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.enabled is True
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
assert config.max_guidance_weight == 10.0
|
||||
assert config.execution_horizon == 10
|
||||
|
||||
@@ -22,7 +22,7 @@ from lerobot.configs.types import (
|
||||
PolicyFeature,
|
||||
RTCAttentionSchedule,
|
||||
)
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor import TransitionKey, batch_to_transition, create_transition
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.processor.relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
@@ -52,6 +52,9 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
|
||||
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
|
||||
RTCProcessor = _rtc_mod.RTCProcessor
|
||||
|
||||
_rtc_relative_mod = _import_rtc_module("lerobot.policies.rtc.relative", "relative.py")
|
||||
reanchor_relative_rtc_prefix = _rtc_relative_mod.reanchor_relative_rtc_prefix
|
||||
|
||||
ACTION_DIM = 6
|
||||
CHUNK_SIZE = 50
|
||||
EXECUTION_HORIZON = 10
|
||||
@@ -187,7 +190,7 @@ class TestRTCDenoiseWithRelativeLeftovers:
|
||||
|
||||
|
||||
class TestFullPipelineRelativeRTC:
|
||||
"""End-to-end test of the RTC + relative actions pipeline matching eval_with_real_robot.py flow."""
|
||||
"""End-to-end test of the RTC + relative actions pipeline matching lerobot-rollout flow."""
|
||||
|
||||
def test_preprocessor_caches_state_for_postprocessor(self):
|
||||
"""Preprocessor's relative step should cache state so postprocessor can convert back."""
|
||||
@@ -218,7 +221,9 @@ class TestFullPipelineRelativeRTC:
|
||||
|
||||
def test_roundtrip_with_identity_normalization(self):
|
||||
"""Actions → relative → normalize → [model] → unnormalize → absolute should recover originals.
|
||||
Using mean=0, std=1 normalization (identity)."""
|
||||
|
||||
Using mean=0, std=1 normalization (identity).
|
||||
"""
|
||||
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
|
||||
|
||||
state = torch.randn(1, ACTION_DIM)
|
||||
@@ -240,7 +245,7 @@ class TestFullPipelineRelativeRTC:
|
||||
torch.testing.assert_close(recovered, actions, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_eval_loop_simulation(self):
|
||||
"""Simulate the eval_with_real_robot.py loop with relative actions.
|
||||
"""Simulate the lerobot-rollout loop with relative actions.
|
||||
|
||||
Iteration 1: No leftovers → model generates relative actions → store for RTC
|
||||
Iteration 2: Use leftovers as RTC guidance → model generates new relative actions
|
||||
@@ -400,13 +405,113 @@ class TestStateRebasingApproximation:
|
||||
assert error_excluded < 1e-6, f"Excluded joint should have zero error, got {error_excluded}"
|
||||
|
||||
|
||||
class TestRTCReanchoringWithStateNormalizer:
|
||||
"""RTC re-anchoring under non-identity OBS_STATE normalization."""
|
||||
|
||||
@staticmethod
|
||||
def _build_normalizer_with_state_stats():
|
||||
"""Build a relative-action preprocessor with non-trivial OBS_STATE stats."""
|
||||
features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(ACTION_DIM,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
stats = {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(ACTION_DIM).numpy(),
|
||||
"std": (0.5 * torch.ones(ACTION_DIM)).numpy(),
|
||||
},
|
||||
OBS_STATE: {
|
||||
"mean": (5.0 * torch.ones(ACTION_DIM)).numpy(),
|
||||
"std": (2.0 * torch.ones(ACTION_DIM)).numpy(),
|
||||
},
|
||||
}
|
||||
relative_step = RelativeActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
return relative_step, normalizer
|
||||
|
||||
def test_reanchor_with_raw_state_matches_normalize_of_absolute_minus_state(self):
|
||||
"""Reanchoring with the raw cached state yields ``normalize(prev_actions_absolute - raw_state)``."""
|
||||
relative_step, normalizer = self._build_normalizer_with_state_stats()
|
||||
|
||||
raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
|
||||
relative_step(batch_to_transition({OBS_STATE: raw_state.clone()}))
|
||||
|
||||
prev_actions_absolute = torch.tensor([[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]] * 5)
|
||||
|
||||
result = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_absolute,
|
||||
current_state=relative_step.get_cached_state(),
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer,
|
||||
policy_device="cpu",
|
||||
)
|
||||
|
||||
expected_relative = to_relative_actions(prev_actions_absolute, raw_state, [True] * ACTION_DIM)
|
||||
expected = normalizer(create_transition(action=expected_relative))[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_reanchor_with_normalized_state_produces_wrong_result(self):
|
||||
"""Reanchoring with raw vs. normalized state produces meaningfully different outputs."""
|
||||
relative_step, normalizer = self._build_normalizer_with_state_stats()
|
||||
|
||||
raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
|
||||
relative_step(batch_to_transition({OBS_STATE: raw_state.clone()}))
|
||||
|
||||
normalized_obs = normalizer(batch_to_transition({OBS_STATE: raw_state.clone()}))
|
||||
normalized_state = normalized_obs[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
assert not torch.allclose(normalized_state, raw_state)
|
||||
|
||||
prev_actions_absolute = torch.tensor([[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]] * 5)
|
||||
|
||||
result_raw = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_absolute,
|
||||
current_state=raw_state,
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer,
|
||||
policy_device="cpu",
|
||||
)
|
||||
result_normalized = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_absolute,
|
||||
current_state=normalized_state,
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer,
|
||||
policy_device="cpu",
|
||||
)
|
||||
|
||||
max_abs_diff = (result_raw - result_normalized).abs().max()
|
||||
assert max_abs_diff > 0.5, (
|
||||
f"Raw and normalized state produced near-identical outputs (max diff {max_abs_diff:.4f}); "
|
||||
"OBS_STATE stats are too close to identity to be sensitive."
|
||||
)
|
||||
|
||||
def test_engine_pipeline_cached_state_is_raw_after_full_preprocess(self):
|
||||
"""``get_cached_state()`` returns raw OBS_STATE after the full preprocessor pipeline runs."""
|
||||
relative_step, normalizer = self._build_normalizer_with_state_stats()
|
||||
|
||||
raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
|
||||
|
||||
transition = batch_to_transition({OBS_STATE: raw_state.clone()})
|
||||
transition = relative_step(transition)
|
||||
preprocessed = normalizer(transition)
|
||||
|
||||
cached = relative_step.get_cached_state()
|
||||
torch.testing.assert_close(cached, raw_state, atol=1e-6, rtol=1e-6)
|
||||
|
||||
post_normalize_state = preprocessed[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
assert not torch.allclose(cached, post_normalize_state, atol=1e-3)
|
||||
|
||||
|
||||
def _detect_relative_actions(preprocessor) -> bool:
|
||||
"""Mirror of the helper in eval_with_real_robot.py for testing without importing it."""
|
||||
"""Mirror of the helper in lerobot-rollout for testing without importing it."""
|
||||
return any(isinstance(step, RelativeActionsProcessorStep) and step.enabled for step in preprocessor.steps)
|
||||
|
||||
|
||||
class TestDetectRelativeActions:
|
||||
"""Test the _detect_relative_actions helper logic used by eval_with_real_robot.py."""
|
||||
"""Test the _detect_relative_actions helper logic used by lerobot-rollout."""
|
||||
|
||||
def test_detects_enabled_relative_step(self):
|
||||
class FakePipeline:
|
||||
|
||||
@@ -24,10 +24,6 @@ def lerobot_train(args):
|
||||
return run_command(cmd="lerobot-train", module="lerobot_train", args=args)
|
||||
|
||||
|
||||
def lerobot_record(args):
|
||||
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
||||
|
||||
|
||||
def resolve_model_id_for_peft_training(policy_type):
|
||||
"""PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
|
||||
if policy_type == "smolvla":
|
||||
@@ -155,81 +151,3 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
||||
f"--output_dir={output_dir}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DummyRobot:
|
||||
name = "dummy"
|
||||
cameras = []
|
||||
action_features = {"foo": 1.0, "bar": 2.0}
|
||||
observation_features = {"obs1": 1.0, "obs2": 2.0}
|
||||
is_connected = True
|
||||
|
||||
def connect(self, *args):
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
|
||||
def dummy_make_robot_from_config(*args, **kwargs):
|
||||
return DummyRobot()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||
@skip_if_package_missing("peft")
|
||||
def test_peft_record_loads_policy(policy_type, tmp_path):
|
||||
"""Train a policy with PEFT and attempt to load it with `lerobot-record`."""
|
||||
from peft import PeftModel
|
||||
|
||||
output_dir = tmp_path / f"output_{policy_type}"
|
||||
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||
|
||||
lerobot_train(
|
||||
[
|
||||
f"--policy.path={model_id}",
|
||||
"--policy.push_to_hub=false",
|
||||
"--policy.input_features=null",
|
||||
"--policy.output_features=null",
|
||||
"--peft.method=LORA",
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0, 1]",
|
||||
"--steps=1",
|
||||
f"--output_dir={output_dir}",
|
||||
]
|
||||
)
|
||||
|
||||
policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
|
||||
dataset_dir = tmp_path / "eval_pusht"
|
||||
single_task = "move the table"
|
||||
loaded_policy = None
|
||||
|
||||
def dummy_record_loop(*args, **kwargs):
|
||||
nonlocal loaded_policy
|
||||
|
||||
if "dataset" not in kwargs:
|
||||
return
|
||||
|
||||
dataset = kwargs["dataset"]
|
||||
dataset.add_frame({"task": single_task})
|
||||
loaded_policy = kwargs["policy"]
|
||||
|
||||
with (
|
||||
patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config),
|
||||
# disable record loop since we're only interested in successful loading of the policy.
|
||||
patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop),
|
||||
# disable speech output
|
||||
patch("lerobot.utils.utils.say"),
|
||||
):
|
||||
lerobot_record(
|
||||
[
|
||||
f"--policy.path={policy_dir}",
|
||||
"--robot.type=so101_follower",
|
||||
"--robot.port=/dev/null",
|
||||
"--dataset.repo_id=lerobot/eval_pusht",
|
||||
f'--dataset.single_task="{single_task}"',
|
||||
f"--dataset.root={dataset_dir}",
|
||||
"--dataset.push_to_hub=false",
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(loaded_policy, PeftModel)
|
||||
|
||||
@@ -21,8 +21,9 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
|
||||
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
|
||||
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
|
||||
from lerobot.scripts.lerobot_record import RecordConfig, record
|
||||
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
|
||||
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Minimal tests for the rollout module's public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import smoke tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_rollout_top_level_imports():
|
||||
import lerobot.rollout
|
||||
|
||||
for name in lerobot.rollout.__all__:
|
||||
assert hasattr(lerobot.rollout, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
def test_inference_submodule_imports():
|
||||
import lerobot.rollout.inference
|
||||
|
||||
for name in lerobot.rollout.inference.__all__:
|
||||
assert hasattr(lerobot.rollout.inference, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
def test_strategies_submodule_imports():
|
||||
import lerobot.rollout.strategies
|
||||
|
||||
for name in lerobot.rollout.strategies.__all__:
|
||||
assert hasattr(lerobot.rollout.strategies, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
|
||||
assert BaseStrategyConfig().type == "base"
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
|
||||
DAggerStrategyConfig(input_device="joystick")
|
||||
|
||||
|
||||
def test_dagger_config_defaults():
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
cfg = DAggerStrategyConfig()
|
||||
assert cfg.num_episodes is None
|
||||
assert cfg.record_autonomous is False
|
||||
assert cfg.input_device == "keyboard"
|
||||
|
||||
|
||||
def test_inference_config_types():
|
||||
from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig
|
||||
|
||||
assert SyncInferenceConfig().type == "sync"
|
||||
|
||||
rtc = RTCInferenceConfig()
|
||||
assert rtc.type == "rtc"
|
||||
assert rtc.queue_threshold == 30
|
||||
assert rtc.rtc is not None
|
||||
|
||||
|
||||
def test_sentry_config_defaults():
|
||||
from lerobot.rollout import SentryStrategyConfig
|
||||
|
||||
cfg = SentryStrategyConfig()
|
||||
assert cfg.upload_every_n_episodes == 5
|
||||
assert cfg.target_video_file_size_mb is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RolloutRingBuffer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ring_buffer_append_and_eviction():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
||||
# max_frames = 5
|
||||
for i in range(8):
|
||||
buf.append({"val": i})
|
||||
assert len(buf) == 5
|
||||
|
||||
|
||||
def test_ring_buffer_drain():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
for i in range(3):
|
||||
buf.append({"val": i})
|
||||
frames = buf.drain()
|
||||
assert len(frames) == 3
|
||||
assert len(buf) == 0
|
||||
assert buf.estimated_bytes == 0
|
||||
|
||||
|
||||
def test_ring_buffer_clear():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
buf.append({"val": 1})
|
||||
buf.clear()
|
||||
assert len(buf) == 0
|
||||
assert buf.estimated_bytes == 0
|
||||
|
||||
|
||||
def test_ring_buffer_tensor_bytes():
|
||||
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||
buf.append({"tensor": t})
|
||||
assert buf.estimated_bytes >= 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThreadSafeRobot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_thread_safe_robot_delegates():
|
||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
robot.connect()
|
||||
wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
obs = wrapper.get_observation()
|
||||
assert "motor_1.pos" in obs
|
||||
assert "motor_2.pos" in obs
|
||||
assert "motor_3.pos" in obs
|
||||
|
||||
action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
|
||||
result = wrapper.send_action(action)
|
||||
assert result == action
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_thread_safe_robot_properties():
|
||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
robot.connect()
|
||||
wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
assert wrapper.name == "mock_robot"
|
||||
assert "motor_1.pos" in wrapper.observation_features
|
||||
assert "motor_1.pos" in wrapper.action_features
|
||||
assert wrapper.is_connected is True
|
||||
assert wrapper.inner is robot
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_strategy_dispatches():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategy,
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategy,
|
||||
DAggerStrategyConfig,
|
||||
SentryStrategy,
|
||||
SentryStrategyConfig,
|
||||
create_strategy,
|
||||
)
|
||||
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
from lerobot.rollout import create_strategy
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.type = "bogus"
|
||||
with pytest.raises(ValueError, match="Unknown strategy type"):
|
||||
create_strategy(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_inference_engine_sync():
|
||||
from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
|
||||
|
||||
engine = create_inference_engine(
|
||||
SyncInferenceConfig(),
|
||||
policy=MagicMock(),
|
||||
preprocessor=MagicMock(),
|
||||
postprocessor=MagicMock(),
|
||||
robot_wrapper=MagicMock(robot_type="mock"),
|
||||
hw_features={},
|
||||
dataset_features={},
|
||||
ordered_action_keys=["k"],
|
||||
task="test",
|
||||
fps=30.0,
|
||||
device="cpu",
|
||||
)
|
||||
assert isinstance(engine, SyncInferenceEngine)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_estimate_max_episode_seconds_no_video():
|
||||
from lerobot.rollout.strategies import estimate_max_episode_seconds
|
||||
|
||||
assert estimate_max_episode_seconds({}, fps=30.0) == 300.0
|
||||
|
||||
|
||||
def test_estimate_max_episode_seconds_with_video():
|
||||
from lerobot.rollout.strategies import estimate_max_episode_seconds
|
||||
|
||||
features = {"cam": {"dtype": "video", "shape": (480, 640, 3)}}
|
||||
result = estimate_max_episode_seconds(features, fps=30.0)
|
||||
assert result > 0
|
||||
# With a real camera, duration should differ from the fallback
|
||||
assert result != 300.0
|
||||
|
||||
|
||||
def test_safe_push_to_hub():
|
||||
from lerobot.rollout.strategies import safe_push_to_hub
|
||||
|
||||
ds = MagicMock()
|
||||
ds.num_episodes = 0
|
||||
assert safe_push_to_hub(ds) is False
|
||||
ds.push_to_hub.assert_not_called()
|
||||
|
||||
ds.num_episodes = 5
|
||||
assert safe_push_to_hub(ds, tags=["test"]) is True
|
||||
ds.push_to_hub.assert_called_once_with(tags=["test"], private=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DAgger state machine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_dagger_full_transition_cycle():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
|
||||
# AUTONOMOUS -> PAUSED
|
||||
events.request_transition("pause_resume")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED)
|
||||
|
||||
# PAUSED -> CORRECTING
|
||||
events.request_transition("correction")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.CORRECTING)
|
||||
|
||||
# CORRECTING -> PAUSED
|
||||
events.request_transition("correction")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.CORRECTING, DAggerPhase.PAUSED)
|
||||
|
||||
# PAUSED -> AUTONOMOUS
|
||||
events.request_transition("pause_resume")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS)
|
||||
|
||||
|
||||
def test_dagger_invalid_transition_ignored():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("correction") # Not valid from AUTONOMOUS
|
||||
assert events.consume_transition() is None
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
|
||||
|
||||
def test_dagger_events_reset():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("pause_resume")
|
||||
events.consume_transition() # -> PAUSED
|
||||
events.upload_requested.set()
|
||||
events.reset()
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
assert not events.upload_requested.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_rollout_context_fields():
|
||||
from lerobot.rollout import RolloutContext
|
||||
|
||||
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
|
||||
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
|
||||
@@ -24,7 +24,7 @@ import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
||||
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||
|
||||
Reference in New Issue
Block a user