diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3dcba5993..582d4fc14 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -61,6 +61,8 @@ title: SARM title: "Reward Models" - sections: + - local: inference + title: Policy Deployment (lerobot-rollout) - local: async title: Use Async Inference - local: rtc diff --git a/docs/source/hil_data_collection.mdx b/docs/source/hil_data_collection.mdx index a8c772658..465ef045d 100644 --- a/docs/source/hil_data_collection.mdx +++ b/docs/source/hil_data_collection.mdx @@ -111,8 +111,7 @@ lerobot-rollout --strategy.type=dagger \ --dataset.repo_id=your-username/hil-dataset \ --dataset.single_task="Fold the T-shirt properly" \ --dataset.fps=30 \ - --dataset.episode_time_s=1000 \ - --dataset.num_episodes=50 \ + --strategy.num_episodes=50 \ --interpolation_multiplier=2 ``` @@ -139,8 +138,7 @@ lerobot-rollout --strategy.type=dagger \ --dataset.repo_id=your-username/hil-rtc-dataset \ --dataset.single_task="Fold the T-shirt properly" \ --dataset.fps=30 \ - --dataset.episode_time_s=1000 \ - --dataset.num_episodes=50 \ + --strategy.num_episodes=50 \ --interpolation_multiplier=3 ``` diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 1c6a6c543..ff0a6229e 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -528,7 +528,6 @@ lerobot-rollout \ ```bash lerobot-rollout \ --strategy.type=sentry \ - --strategy.episode_duration_s=60 \ --strategy.upload_every_n_episodes=5 \ --policy.path=${HF_USER}/my_policy \ --robot.type=so100_follower \ diff --git a/docs/source/inference.mdx b/docs/source/inference.mdx new file mode 100644 index 000000000..4a941fccd --- /dev/null +++ b/docs/source/inference.mdx @@ -0,0 +1,262 @@ +# Policy Deployment (lerobot-rollout) + +`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection. + +## Quick Start + +No extra dependencies are needed beyond your robot and policy extras. + +```bash +lerobot-rollout \ + --strategy.type=base \ + --policy.path=lerobot/act_koch_real \ + --robot.type=koch_follower \ + --robot.port=/dev/ttyACM0 \ + --task="pick up cube" \ + --duration=30 +``` + +This runs the policy for 30 seconds with no recording. + +--- + +## Strategies + +Select a strategy with `--strategy.type=`. Each strategy defines a different control loop with its own recording and interaction semantics. + +### Base (`--strategy.type=base`) + +Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot. + +```bash +lerobot-rollout \ + --strategy.type=base \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --task="Put lego brick into the box" \ + --duration=60 +``` + +| Flag | Description | +| ---------------- | ------------------------------------------------------ | +| `--duration` | Run time in seconds (0 = infinite) | +| `--task` | Task description passed to the policy | +| `--display_data` | Stream observations/actions to Rerun for visualization | + +### Sentry (`--strategy.type=sentry`) + +Continuous autonomous recording with periodic upload to the Hugging Face Hub. 
Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient. + +Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes. + +```bash +lerobot-rollout \ + --strategy.type=sentry \ + --strategy.upload_every_n_episodes=5 \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --dataset.repo_id=${HF_USER}/eval_data \ + --dataset.single_task="Put lego brick into the box" \ + --duration=3600 +``` + +| Flag | Description | +| -------------------------------------- | ----------------------------------------------------------- | +| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) | +| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) | +| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset | +| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) | + +### Highlight (`--strategy.type=highlight`) + +Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode. + +```bash +lerobot-rollout \ + --strategy.type=highlight \ + --strategy.ring_buffer_seconds=30 \ + --strategy.save_key=s \ + --strategy.push_key=h \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=koch_follower \ + --robot.port=/dev/ttyACM0 \ + --dataset.repo_id=${HF_USER}/highlight_data \ + --dataset.single_task="Pick up the red cube" +``` + +**Keyboard controls:** + +| Key | Action | +| ------------------ | -------------------------------------------------------- | +| `s` (configurable) | Start recording (flushes buffer) / stop and save episode | +| `h` (configurable) | Push dataset to Hub | +| `ESC` | Stop the session | + +| Flag | Description | +| -------------------------------------- | ---------------------------------------------- | +| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) | +| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) | +| `--strategy.save_key` | Key to toggle recording (default: `s`) | +| `--strategy.push_key` | Key to push to Hub (default: `h`) | + +### DAgger (`--strategy.type=dagger`) + +Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`). + +See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough. + +**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode. + +```bash +lerobot-rollout \ + --strategy.type=dagger \ + --strategy.num_episodes=20 \ + --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ + --robot.type=bi_openarm_follower \ + --teleop.type=openarm_mini \ + --dataset.repo_id=${HF_USER}/hil_data \ + --dataset.single_task="Fold the T-shirt" +``` + +**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry). 
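Because autonomous and correction frames land in the same dataset, downstream training code can separate them with the per-frame `intervention` flag. A minimal filtering sketch (assuming the flag is materialized as a per-frame `intervention` column; `user/dagger_data` is a placeholder repo id):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

ds = LeRobotDataset("user/dagger_data")
# Read the flags from the tabular backend without decoding any video frames.
flags = ds.hf_dataset["intervention"]
correction_frames = [i for i, is_human in enumerate(flags) if is_human]
```

The full command for continuous recording mode: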
+ +```bash +lerobot-rollout \ + --strategy.type=dagger \ + --strategy.record_autonomous=true \ + --strategy.num_episodes=50 \ + --policy.path=${HF_USER}/my_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --teleop.type=so101_leader \ + --teleop.port=/dev/ttyACM1 \ + --dataset.repo_id=${HF_USER}/dagger_data \ + --dataset.single_task="Grasp the block" +``` + +**Keyboard controls** (default input device): + +| Key | Action | +| ------- | ------------------------------------------- | +| `Space` | Pause / resume policy execution | +| `Tab` | Start / stop human correction | +| `Enter` | Push dataset to Hub (corrections-only mode) | +| `ESC` | Stop the session | + +Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags. + +| Flag | Description | +| ------------------------------------ | ------------------------------------------------------- | +| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) | +| `--strategy.record_autonomous` | Record autonomous frames too (default: false) | +| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) | +| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) | +| `--teleop.type` | **Required.** Teleoperator type | + +--- + +## Inference Backends + +Select a backend with `--inference.type=`. All strategies work with both backends. + +### Sync (default) + +One policy call per control tick. The main loop blocks until the action is computed. + +Works with all policies. No extra flags needed. + +### Real-Time Chunking (`--inference.type=rtc`) + +A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel. + +Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency. + +```bash +lerobot-rollout \ + --strategy.type=base \ + --inference.type=rtc \ + --inference.rtc.execution_horizon=10 \ + --inference.rtc.max_guidance_weight=10.0 \ + --policy.path=${HF_USER}/pi0_policy \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --task="Pick up the cube" \ + --duration=60 \ + --device=cuda +``` + +| Flag | Description | +| ------------------------------------------- | -------------------------------------------------------------- | +| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) | +| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) | +| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` | +| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) | + +See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters. + +--- + +## Common Flags + +| Flag | Description | Default | +| --------------------------------- | ----------------------------------------------------------------- | ------- | +| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- | +| `--robot.type` | **Required.** Robot type (e.g. 
`so100_follower`, `koch_follower`) | -- |
| `--robot.port`                    | Serial port for the robot                                          | --      |
| `--robot.cameras`                 | Camera configuration (JSON dict)                                   | --      |
| `--fps`                           | Control loop frequency                                             | 30      |
| `--duration`                      | Run time in seconds (0 = infinite)                                 | 0       |
| `--device`                        | Torch device (`cpu`, `cuda`, `mps`)                                | auto    |
| `--task`                          | Task description (used when no dataset is provided)                | --      |
| `--display_data`                  | Stream telemetry to Rerun visualization                            | false   |
| `--display_ip` / `--display_port` | Remote Rerun server address                                        | --      |
| `--interpolation_multiplier`      | Action interpolation factor                                        | 1       |
| `--use_torch_compile`             | Enable `torch.compile` for inference                               | false   |
| `--resume`                        | Resume a previous recording session                                | false   |
| `--play_sounds`                   | Announce session events via speech synthesis                       | true    |

---

## Programmatic Usage

For custom deployments (e.g. with kinematics processors), use the rollout module API directly:

```python
from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies.base import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler

cfg = RolloutConfig(
    robot=my_robot_config,
    policy=my_policy_config,
    strategy=BaseStrategyConfig(),
    inference=SyncInferenceConfig(),
    fps=30,
    duration=60,
    task="my task",
)

signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
    cfg,
    signal_handler.shutdown_event,
    robot_action_processor=my_custom_action_processor,  # optional
    robot_observation_processor=my_custom_obs_processor,  # optional
)

strategy = BaseStrategy(cfg.strategy)
try:
    strategy.setup(ctx)
    strategy.run(ctx)
finally:
    strategy.teardown(ctx)
```

See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py
index 67737dae5..22ed39031 100644
--- a/src/lerobot/rollout/configs.py
+++ b/src/lerobot/rollout/configs.py
@@ -211,6 +211,7 @@ class RolloutConfig:
     compile_warmup_inferences: int = 2

     def __post_init__(self):
+        """Validate config invariants and load the policy config from ``--policy.path``."""
         # --- Strategy-specific validation ---
         if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
             raise ValueError("DAgger strategy requires --teleop.type to be set")
diff --git a/src/lerobot/rollout/inference/rtc.py b/src/lerobot/rollout/inference/rtc.py
index f905aee50..ae8719b77 100644
--- a/src/lerobot/rollout/inference/rtc.py
+++ b/src/lerobot/rollout/inference/rtc.py
@@ -73,7 +73,13 @@ def _reanchor_relative_rtc_prefix(
     normalizer_step: NormalizerProcessorStep | None,
     policy_device: torch.device | str,
 ) -> torch.Tensor:
-    """Convert absolute leftover actions into model-space for relative-action RTC policies."""
+    """Convert absolute leftover actions into model-space for relative-action RTC policies.
+
+    When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
+    is stored in absolute coordinates. Before feeding it back to the policy, this
+    helper re-expresses those actions relative to the robot's current joint state
+    and optionally normalizes them so the policy receives correctly scaled inputs.
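+
+    Example (illustrative; assumes a plain offset convention): an absolute
+    leftover action ``[12.0]`` with ``current_state = [10.0]`` re-anchors to
+    ``[2.0]``, which is then optionally passed through ``normalizer_step``.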
+    """
     state = current_state.detach().cpu()
     if state.dim() == 1:
         state = state.unsqueeze(0)
diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py
index 6b67e4ad5..f0f146109 100644
--- a/src/lerobot/rollout/strategies/core.py
+++ b/src/lerobot/rollout/strategies/core.py
@@ -52,10 +52,12 @@ class RolloutStrategy(abc.ABC):
         self._warmup_flushed: bool = False

     def _init_engine(self, ctx: RolloutContext) -> None:
-        """Attach the inference engine + interpolator and start the backend.
+        """Attach the inference engine and action interpolator, then start the backend.

-        Call this from ``setup()`` so strategies share identical setup
-        without duplicating code.
+        Creates an :class:`ActionInterpolator` from the config's
+        ``interpolation_multiplier`` and starts the inference engine.
+        Call this from ``setup()`` so strategies share identical
+        initialization without duplicating code.
         """
         self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
         self._engine = ctx.policy.inference
diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py
index 4546227c6..b346e50b8 100644
--- a/src/lerobot/rollout/strategies/dagger.py
+++ b/src/lerobot/rollout/strategies/dagger.py
@@ -175,8 +175,8 @@ def _teleop_smooth_move_to(
 ) -> None:
     """Smoothly move teleop to target position via linear interpolation.

-    The teleoperator is guaranteed to have motor control methods
-    (validated at context build time).
+    Requires the teleoperator to support motor control methods
+    (``enable_torque``, ``write_goal_positions``, ``get_action``).
     """
     teleop.enable_torque()
     current = teleop.get_action()
diff --git a/src/lerobot/scripts/lerobot_rollout.py b/src/lerobot/scripts/lerobot_rollout.py
index 98f6ba65e..fe221047d 100644
--- a/src/lerobot/scripts/lerobot_rollout.py
+++ b/src/lerobot/scripts/lerobot_rollout.py
@@ -19,46 +19,107 @@
 ``lerobot-rollout`` is the single CLI for running trained policies on real robots.
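+
+See ``docs/source/inference.mdx`` for the full deployment guide.
+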
- --strategy.type=base 24/7 autonomous rollout (no recording) +Strategies +---------- + --strategy.type=base Autonomous rollout, no recording --strategy.type=sentry Continuous recording with auto-upload --strategy.type=highlight Ring buffer + keystroke save - --strategy.type=dagger Human-in-the-loop (DAgger/RaC) + --strategy.type=dagger Human-in-the-loop (DAgger / RaC) -Usage examples:: +Inference backends +------------------ + --inference.type=sync One policy call per control tick (default) + --inference.type=rtc Real-Time Chunking for slow VLA models - # Base mode (sync inference) +Usage examples +-------------- +:: + + # Base mode — quick evaluation with sync inference lerobot-rollout \\ --strategy.type=base \\ --policy.path=lerobot/act_koch_real \\ --robot.type=koch_follower \\ + --robot.port=/dev/ttyACM0 \\ --task="pick up cube" --duration=30 - # Base mode (RTC for slow VLAs) + # Base mode — RTC inference for slow VLAs (Pi0, Pi0.5, SmolVLA) lerobot-rollout \\ --strategy.type=base \\ --policy.path=lerobot/pi0_base \\ - --inference.type=rtc --inference.rtc.execution_horizon=10 \\ + --inference.type=rtc \\ + --inference.rtc.execution_horizon=10 \\ + --inference.rtc.max_guidance_weight=10.0 \\ --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\ --task="pick up cube" --duration=60 - # Sentry mode (continuous recording) + # Sentry mode — continuous recording with periodic upload lerobot-rollout \\ --strategy.type=sentry \\ --strategy.upload_every_n_episodes=5 \\ --policy.path=lerobot/pi0_base \\ --inference.type=rtc \\ --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ --dataset.repo_id=user/sentry-data \\ --dataset.single_task="patrol" --duration=3600 - # DAgger mode (human-in-the-loop) + # Highlight mode — ring buffer, press 's' to save, 'h' to push + lerobot-rollout \\ + --strategy.type=highlight \\ + --strategy.ring_buffer_seconds=30 \\ + --policy.path=lerobot/act_koch_real \\ + --robot.type=koch_follower \\ + --robot.port=/dev/ttyACM0 \\ + --dataset.repo_id=user/highlight-data \\ + --dataset.single_task="pick up cube" + + # DAgger mode — human-in-the-loop corrections only lerobot-rollout \\ --strategy.type=dagger \\ + --strategy.num_episodes=20 \\ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \\ --robot.type=bi_openarm_follower \\ --teleop.type=openarm_mini \\ --dataset.repo_id=user/hil-data \\ --dataset.single_task="Fold the T-shirt" + + # DAgger mode — continuous recording with RTC inference + lerobot-rollout \\ + --strategy.type=dagger \\ + --strategy.record_autonomous=true \\ + --strategy.num_episodes=50 \\ + --inference.type=rtc \\ + --inference.rtc.execution_horizon=10 \\ + --policy.path=user/my_pi0_policy \\ + --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ + --teleop.type=so101_leader \\ + --teleop.port=/dev/ttyACM1 \\ + --dataset.repo_id=user/dagger-rtc-data \\ + --dataset.single_task="Grasp the block" + + # With Rerun visualization and torch.compile + lerobot-rollout \\ + --strategy.type=base \\ + --policy.path=lerobot/act_koch_real \\ + --robot.type=koch_follower \\ + --robot.port=/dev/ttyACM0 \\ + --task="pick up cube" --duration=60 \\ + --display_data=true \\ + --use_torch_compile=true + + # Resume a previous sentry recording session + lerobot-rollout \\ + --strategy.type=sentry \\ + --policy.path=user/my_policy \\ + --robot.type=so100_follower \\ + --robot.port=/dev/ttyACM0 \\ + --dataset.repo_id=user/sentry-data \\ + 
--dataset.single_task="patrol" \\ + --resume=true """ import logging diff --git a/tests/test_rollout.py b/tests/test_rollout.py new file mode 100644 index 000000000..9f54f581d --- /dev/null +++ b/tests/test_rollout.py @@ -0,0 +1,339 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal tests for the rollout module's public API.""" + +from __future__ import annotations + +import dataclasses +from unittest.mock import MagicMock + +import pytest +import torch + +# --------------------------------------------------------------------------- +# Import smoke tests +# --------------------------------------------------------------------------- + + +def test_rollout_top_level_imports(): + import lerobot.rollout + + for name in lerobot.rollout.__all__: + assert hasattr(lerobot.rollout, name), f"Missing export: {name}" + + +def test_inference_submodule_imports(): + import lerobot.rollout.inference + + for name in lerobot.rollout.inference.__all__: + assert hasattr(lerobot.rollout.inference, name), f"Missing export: {name}" + + +def test_strategies_submodule_imports(): + import lerobot.rollout.strategies + + for name in lerobot.rollout.strategies.__all__: + assert hasattr(lerobot.rollout.strategies, name), f"Missing export: {name}" + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + + +def test_strategy_config_types(): + from lerobot.rollout.configs import ( + BaseStrategyConfig, + DAggerStrategyConfig, + HighlightStrategyConfig, + SentryStrategyConfig, + ) + + assert BaseStrategyConfig().type == "base" + assert SentryStrategyConfig().type == "sentry" + assert HighlightStrategyConfig().type == "highlight" + assert DAggerStrategyConfig().type == "dagger" + + +def test_dagger_config_invalid_input_device(): + from lerobot.rollout.configs import DAggerStrategyConfig + + with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"): + DAggerStrategyConfig(input_device="joystick") + + +def test_dagger_config_defaults(): + from lerobot.rollout.configs import DAggerStrategyConfig + + cfg = DAggerStrategyConfig() + assert cfg.num_episodes == 10 + assert cfg.record_autonomous is False + assert cfg.input_device == "keyboard" + + +def test_inference_config_types(): + from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig + + assert SyncInferenceConfig().type == "sync" + + rtc = RTCInferenceConfig() + assert rtc.type == "rtc" + assert rtc.queue_threshold == 30 + assert rtc.rtc is not None + + +def test_sentry_config_defaults(): + from lerobot.rollout.configs import SentryStrategyConfig + + cfg = SentryStrategyConfig() + assert cfg.upload_every_n_episodes == 5 + assert cfg.target_video_file_size_mb is None + + +# --------------------------------------------------------------------------- +# RolloutRingBuffer +# 
--------------------------------------------------------------------------- + + +def test_ring_buffer_append_and_eviction(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0) + # max_frames = 5 + for i in range(8): + buf.append({"val": i}) + assert len(buf) == 5 + + +def test_ring_buffer_drain(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) + for i in range(3): + buf.append({"val": i}) + frames = buf.drain() + assert len(frames) == 3 + assert len(buf) == 0 + assert buf.estimated_bytes == 0 + + +def test_ring_buffer_clear(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) + buf.append({"val": 1}) + buf.clear() + assert len(buf) == 0 + assert buf.estimated_bytes == 0 + + +def test_ring_buffer_tensor_bytes(): + from lerobot.rollout.ring_buffer import RolloutRingBuffer + + buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) + t = torch.zeros(100, dtype=torch.float32) # 400 bytes + buf.append({"tensor": t}) + assert buf.estimated_bytes >= 400 + + +# --------------------------------------------------------------------------- +# ThreadSafeRobot +# --------------------------------------------------------------------------- + + +def test_thread_safe_robot_delegates(): + from lerobot.rollout.robot_wrapper import ThreadSafeRobot + from tests.mocks.mock_robot import MockRobot, MockRobotConfig + + robot = MockRobot(MockRobotConfig(n_motors=3)) + robot.connect() + wrapper = ThreadSafeRobot(robot) + + obs = wrapper.get_observation() + assert "motor_1.pos" in obs + assert "motor_2.pos" in obs + assert "motor_3.pos" in obs + + action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0} + result = wrapper.send_action(action) + assert result == action + + robot.disconnect() + + +def test_thread_safe_robot_properties(): + from lerobot.rollout.robot_wrapper import ThreadSafeRobot + from tests.mocks.mock_robot import MockRobot, MockRobotConfig + + robot = MockRobot(MockRobotConfig(n_motors=3)) + robot.connect() + wrapper = ThreadSafeRobot(robot) + + assert wrapper.name == "mock_robot" + assert "motor_1.pos" in wrapper.observation_features + assert "motor_1.pos" in wrapper.action_features + assert wrapper.is_connected is True + assert wrapper.inner is robot + + robot.disconnect() + + +# --------------------------------------------------------------------------- +# Strategy factory +# --------------------------------------------------------------------------- + + +def test_create_strategy_dispatches(): + from lerobot.rollout.configs import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig + from lerobot.rollout.strategies import create_strategy + from lerobot.rollout.strategies.base import BaseStrategy + from lerobot.rollout.strategies.dagger import DAggerStrategy + from lerobot.rollout.strategies.sentry import SentryStrategy + + assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy) + assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy) + assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy) + + +def test_create_strategy_unknown_raises(): + from lerobot.rollout.strategies import create_strategy + + cfg = MagicMock() + cfg.type = "bogus" + with pytest.raises(ValueError, match="Unknown strategy type"): + create_strategy(cfg) + + +# 
--------------------------------------------------------------------------- +# Inference factory +# --------------------------------------------------------------------------- + + +def test_create_inference_engine_sync(): + from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine + + engine = create_inference_engine( + SyncInferenceConfig(), + policy=MagicMock(), + preprocessor=MagicMock(), + postprocessor=MagicMock(), + robot_wrapper=MagicMock(robot_type="mock"), + hw_features={}, + dataset_features={}, + ordered_action_keys=["k"], + task="test", + fps=30.0, + device="cpu", + ) + assert isinstance(engine, SyncInferenceEngine) + + +# --------------------------------------------------------------------------- +# Pure functions +# --------------------------------------------------------------------------- + + +def test_estimate_max_episode_seconds_no_video(): + from lerobot.rollout.strategies import estimate_max_episode_seconds + + assert estimate_max_episode_seconds({}, fps=30.0) == 600.0 + + +def test_estimate_max_episode_seconds_with_video(): + from lerobot.rollout.strategies import estimate_max_episode_seconds + + features = {"cam": {"dtype": "video", "shape": (3, 480, 640)}} + result = estimate_max_episode_seconds(features, fps=30.0) + assert result > 0 + # With a real camera, duration should differ from the fallback + assert result != 600.0 + + +def test_safe_push_to_hub(): + from lerobot.rollout.strategies import safe_push_to_hub + + ds = MagicMock() + ds.num_episodes = 0 + assert safe_push_to_hub(ds) is False + ds.push_to_hub.assert_not_called() + + ds.num_episodes = 5 + assert safe_push_to_hub(ds, tags=["test"]) is True + ds.push_to_hub.assert_called_once_with(tags=["test"], private=False) + + +# --------------------------------------------------------------------------- +# DAgger state machine +# --------------------------------------------------------------------------- + + +def test_dagger_full_transition_cycle(): + from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase + + events = DAggerEvents() + assert events.phase == DAggerPhase.AUTONOMOUS + + # AUTONOMOUS -> PAUSED + events.request_transition("pause_resume") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED) + + # PAUSED -> CORRECTING + events.request_transition("correction") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.CORRECTING) + + # CORRECTING -> PAUSED + events.request_transition("correction") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.CORRECTING, DAggerPhase.PAUSED) + + # PAUSED -> AUTONOMOUS + events.request_transition("pause_resume") + old, new = events.consume_transition() + assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS) + + +def test_dagger_invalid_transition_ignored(): + from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase + + events = DAggerEvents() + events.request_transition("correction") # Not valid from AUTONOMOUS + assert events.consume_transition() is None + assert events.phase == DAggerPhase.AUTONOMOUS + + +def test_dagger_events_reset(): + from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase + + events = DAggerEvents() + events.request_transition("pause_resume") + events.consume_transition() # -> PAUSED + events.upload_requested.set() + events.reset() + assert events.phase == DAggerPhase.AUTONOMOUS + assert not events.upload_requested.is_set() 
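+
+
+def test_dagger_consume_transition_is_one_shot():
+    # Sketch based on the semantics shown above, assuming consume_transition()
+    # clears the pending transition: a second call without a new request
+    # should return None.
+    from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase
+
+    events = DAggerEvents()
+    events.request_transition("pause_resume")
+    old, new = events.consume_transition()
+    assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED)
+    assert events.consume_transition() is None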
+ + +# --------------------------------------------------------------------------- +# Context dataclass +# --------------------------------------------------------------------------- + + +def test_rollout_context_fields(): + from lerobot.rollout.context import RolloutContext + + field_names = {f.name for f in dataclasses.fields(RolloutContext)} + assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
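+
+
+def test_rollout_context_accepts_duck_typed_components():
+    # Sketch, assuming RolloutContext is a plain dataclass without
+    # __post_init__ validation; mocks stand in for each component.
+    from lerobot.rollout.context import RolloutContext
+
+    ctx = RolloutContext(
+        runtime=MagicMock(),
+        hardware=MagicMock(),
+        policy=MagicMock(),
+        processors=MagicMock(),
+        data=MagicMock(),
+    )
+    assert ctx.runtime is not None
+    assert ctx.policy is not None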