From d70c3baf7c91f92d44f5c0df7d1002715c4d17ef Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 22 Apr 2026 15:18:50 +0200 Subject: [PATCH] chore(rollout): multiple minor improvements --- src/lerobot/configs/dataset.py | 7 ++++- src/lerobot/rollout/__init__.py | 17 ++++++++---- src/lerobot/rollout/configs.py | 4 +-- src/lerobot/rollout/context.py | 6 +++++ src/lerobot/rollout/strategies/core.py | 4 ++- src/lerobot/rollout/strategies/dagger.py | 1 + src/lerobot/rollout/strategies/highlight.py | 21 ++++++++------- src/lerobot/scripts/lerobot_record.py | 1 + tests/test_rollout.py | 29 +++++++++++++-------- 9 files changed, 61 insertions(+), 29 deletions(-) diff --git a/src/lerobot/configs/dataset.py b/src/lerobot/configs/dataset.py index e359aadc7..e3e17e62b 100644 --- a/src/lerobot/configs/dataset.py +++ b/src/lerobot/configs/dataset.py @@ -69,7 +69,12 @@ class DatasetRecordConfig: # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. encoder_threads: int | None = None - def __post_init__(self) -> None: + def stamp_repo_id(self) -> None: + """Append a date-time tag to ``repo_id`` so each recording session gets a unique name. + + Must be called explicitly at dataset *creation* time — not on resume, + where the existing ``repo_id`` (already stamped) must be preserved. + """ if self.repo_id: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.repo_id = f"{self.repo_id}_{timestamp}" diff --git a/src/lerobot/rollout/__init__.py b/src/lerobot/rollout/__init__.py index da8c18acc..a4de8ee6c 100644 --- a/src/lerobot/rollout/__init__.py +++ b/src/lerobot/rollout/__init__.py @@ -46,17 +46,25 @@ from .inference import ( SyncInferenceEngine, create_inference_engine, ) -from .ring_buffer import RolloutRingBuffer -from .robot_wrapper import ThreadSafeRobot -from .strategies import RolloutStrategy, create_strategy +from .strategies import ( + BaseStrategy, + DAggerStrategy, + HighlightStrategy, + RolloutStrategy, + SentryStrategy, + create_strategy, +) __all__ = [ + "BaseStrategy", "BaseStrategyConfig", "DAggerKeyboardConfig", "DAggerPedalConfig", + "DAggerStrategy", "DAggerStrategyConfig", "DatasetContext", "HardwareContext", + "HighlightStrategy", "HighlightStrategyConfig", "InferenceEngine", "InferenceEngineConfig", @@ -66,14 +74,13 @@ __all__ = [ "RTCInferenceEngine", "RolloutConfig", "RolloutContext", - "RolloutRingBuffer", "RolloutStrategy", "RolloutStrategyConfig", "RuntimeContext", + "SentryStrategy", "SentryStrategyConfig", "SyncInferenceConfig", "SyncInferenceEngine", - "ThreadSafeRobot", "build_rollout_context", "create_inference_engine", "create_strategy", diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index 2ee3122dc..f841cc040 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -89,8 +89,8 @@ class HighlightStrategyConfig(RolloutStrategyConfig): again. """ - ring_buffer_seconds: float = 30.0 - ring_buffer_max_memory_mb: float = 2048.0 + ring_buffer_seconds: float = 10.0 + ring_buffer_max_memory_mb: float = 1024.0 save_key: str = "s" push_key: str = "h" diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index ee21be56e..a3bb0cb6f 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -346,6 +346,12 @@ def build_rollout_context( "names": None, } + if not cfg.dataset.repo_id.startswith("rollout_"): + raise ValueError( + "Dataset names for rollout must start with 'rollout_'. " + "Use --dataset.repo_id=rollout_ for policy deployment datasets." + ) + cfg.dataset.stamp_repo_id() dataset = LeRobotDataset.create( cfg.dataset.repo_id, cfg.dataset.fps, diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index 5c336ea3a..3cc8f5f53 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -267,7 +267,9 @@ def send_next_action( if interp is None: return None - action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} + if len(interp) != len(ordered_keys): + raise ValueError(f"Interpolated tensor length ({len(interp)}) != action keys ({len(ordered_keys)})") + action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys)} processed = ctx.processors.robot_action_processor((action_dict, obs_raw)) ctx.hardware.robot_wrapper.send_action(processed) return action_dict diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 66fadee27..d9de631c5 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -693,6 +693,7 @@ class DAggerStrategy(RolloutStrategy): } # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or # user is responsible for moving the teleop to the same position as the robot when starting the correction. + # Consider also a method that moves the robot to the teleop smoothly (similar to what we do at HW shutdown). # _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) elif new_phase == DAggerPhase.CORRECTING: diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index 1d3a7e55a..734f17005 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -22,7 +22,7 @@ import os import sys import time from concurrent.futures import Future, ThreadPoolExecutor -from threading import Event as ThreadingEvent +from threading import Event as ThreadingEvent, Lock from lerobot.common.control_utils import is_headless from lerobot.datasets import VideoEncodingManager @@ -80,6 +80,7 @@ class HighlightStrategy(RolloutStrategy): self._push_requested = ThreadingEvent() self._push_executor: ThreadPoolExecutor | None = None self._pending_push: Future | None = None + self._episode_lock = Lock() def setup(self, ctx: RolloutContext) -> None: """Initialise the inference engine, ring buffer, and keyboard listener.""" @@ -168,7 +169,8 @@ class HighlightStrategy(RolloutStrategy): self._recording_live.set() else: dataset.add_frame(frame) - dataset.save_episode() + with self._episode_lock: + dataset.save_episode() logger.info("Episode saved (total: %d)", dataset.num_episodes) log_say( f"Episode {dataset.num_episodes} saved", @@ -198,7 +200,7 @@ class HighlightStrategy(RolloutStrategy): logger.info("Highlight control loop ended") if self._recording_live.is_set(): logger.info("Saving in-progress live episode") - with contextlib.suppress(Exception): + with contextlib.suppress(Exception), self._episode_lock: dataset.save_episode() def teardown(self, ctx: RolloutContext) -> None: @@ -268,12 +270,13 @@ class HighlightStrategy(RolloutStrategy): def _push(): try: - if safe_push_to_hub( - dataset, - tags=cfg.dataset.tags if cfg.dataset else None, - private=cfg.dataset.private if cfg.dataset else False, - ): - logger.info("Background push to hub complete") + with self._episode_lock: + if safe_push_to_hub( + dataset, + tags=cfg.dataset.tags if cfg.dataset else None, + private=cfg.dataset.private if cfg.dataset else False, + ): + logger.info("Background push to hub complete") except Exception as e: logger.error("Background push failed: %s", e) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index ac80188c5..1a4b1ea66 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -394,6 +394,7 @@ def record( "Dataset names starting with 'eval_' are reserved for policy evaluation. " "lerobot-record is for data collection only. Use lerobot-rollout for policy deployment." ) + cfg.dataset.stamp_repo_id() dataset = LeRobotDataset.create( cfg.dataset.repo_id, cfg.dataset.fps, diff --git a/tests/test_rollout.py b/tests/test_rollout.py index dd3fb25b6..df6963717 100644 --- a/tests/test_rollout.py +++ b/tests/test_rollout.py @@ -86,7 +86,7 @@ def test_dagger_config_defaults(): def test_inference_config_types(): - from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig + from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig assert SyncInferenceConfig().type == "sync" @@ -110,7 +110,7 @@ def test_sentry_config_defaults(): def test_ring_buffer_append_and_eviction(): - from lerobot.rollout import RolloutRingBuffer + from lerobot.rollout.ring_buffer import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0) # max_frames = 5 @@ -120,7 +120,7 @@ def test_ring_buffer_append_and_eviction(): def test_ring_buffer_drain(): - from lerobot.rollout import RolloutRingBuffer + from lerobot.rollout.ring_buffer import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) for i in range(3): @@ -132,7 +132,7 @@ def test_ring_buffer_drain(): def test_ring_buffer_clear(): - from lerobot.rollout import RolloutRingBuffer + from lerobot.rollout.ring_buffer import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf.append({"val": 1}) @@ -142,7 +142,7 @@ def test_ring_buffer_clear(): def test_ring_buffer_tensor_bytes(): - from lerobot.rollout import RolloutRingBuffer + from lerobot.rollout.ring_buffer import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) t = torch.zeros(100, dtype=torch.float32) # 400 bytes @@ -156,7 +156,7 @@ def test_ring_buffer_tensor_bytes(): def test_thread_safe_robot_delegates(): - from lerobot.rollout import ThreadSafeRobot + from lerobot.rollout.robot_wrapper import ThreadSafeRobot from tests.mocks.mock_robot import MockRobot, MockRobotConfig robot = MockRobot(MockRobotConfig(n_motors=3)) @@ -176,7 +176,7 @@ def test_thread_safe_robot_delegates(): def test_thread_safe_robot_properties(): - from lerobot.rollout import ThreadSafeRobot + from lerobot.rollout.robot_wrapper import ThreadSafeRobot from tests.mocks.mock_robot import MockRobot, MockRobotConfig robot = MockRobot(MockRobotConfig(n_motors=3)) @@ -198,8 +198,15 @@ def test_thread_safe_robot_properties(): def test_create_strategy_dispatches(): - from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig - from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy + from lerobot.rollout import ( + BaseStrategy, + BaseStrategyConfig, + DAggerStrategy, + DAggerStrategyConfig, + SentryStrategy, + SentryStrategyConfig, + create_strategy, + ) assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy) assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy) @@ -207,7 +214,7 @@ def test_create_strategy_dispatches(): def test_create_strategy_unknown_raises(): - from lerobot.rollout.strategies import create_strategy + from lerobot.rollout import create_strategy cfg = MagicMock() cfg.type = "bogus" @@ -221,7 +228,7 @@ def test_create_strategy_unknown_raises(): def test_create_inference_engine_sync(): - from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine + from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine engine = create_inference_engine( SyncInferenceConfig(),