chore(rollout): multiple minor improvements

This commit is contained in:
Steven Palma
2026-04-22 15:18:50 +02:00
parent bff2d50dc1
commit d70c3baf7c
9 changed files with 61 additions and 29 deletions
+6 -1
View File
@@ -69,7 +69,12 @@ class DatasetRecordConfig:
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None encoder_threads: int | None = None
def __post_init__(self) -> None: def stamp_repo_id(self) -> None:
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
Must be called explicitly at dataset *creation* time — not on resume,
where the existing ``repo_id`` (already stamped) must be preserved.
"""
if self.repo_id: if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}" self.repo_id = f"{self.repo_id}_{timestamp}"
+12 -5
View File
@@ -46,17 +46,25 @@ from .inference import (
SyncInferenceEngine, SyncInferenceEngine,
create_inference_engine, create_inference_engine,
) )
from .ring_buffer import RolloutRingBuffer from .strategies import (
from .robot_wrapper import ThreadSafeRobot BaseStrategy,
from .strategies import RolloutStrategy, create_strategy DAggerStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
create_strategy,
)
__all__ = [ __all__ = [
"BaseStrategy",
"BaseStrategyConfig", "BaseStrategyConfig",
"DAggerKeyboardConfig", "DAggerKeyboardConfig",
"DAggerPedalConfig", "DAggerPedalConfig",
"DAggerStrategy",
"DAggerStrategyConfig", "DAggerStrategyConfig",
"DatasetContext", "DatasetContext",
"HardwareContext", "HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig", "HighlightStrategyConfig",
"InferenceEngine", "InferenceEngine",
"InferenceEngineConfig", "InferenceEngineConfig",
@@ -66,14 +74,13 @@ __all__ = [
"RTCInferenceEngine", "RTCInferenceEngine",
"RolloutConfig", "RolloutConfig",
"RolloutContext", "RolloutContext",
"RolloutRingBuffer",
"RolloutStrategy", "RolloutStrategy",
"RolloutStrategyConfig", "RolloutStrategyConfig",
"RuntimeContext", "RuntimeContext",
"SentryStrategy",
"SentryStrategyConfig", "SentryStrategyConfig",
"SyncInferenceConfig", "SyncInferenceConfig",
"SyncInferenceEngine", "SyncInferenceEngine",
"ThreadSafeRobot",
"build_rollout_context", "build_rollout_context",
"create_inference_engine", "create_inference_engine",
"create_strategy", "create_strategy",
+2 -2
View File
@@ -89,8 +89,8 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
again. again.
""" """
ring_buffer_seconds: float = 30.0 ring_buffer_seconds: float = 10.0
ring_buffer_max_memory_mb: float = 2048.0 ring_buffer_max_memory_mb: float = 1024.0
save_key: str = "s" save_key: str = "s"
push_key: str = "h" push_key: str = "h"
+6
View File
@@ -346,6 +346,12 @@ def build_rollout_context(
"names": None, "names": None,
} }
if not cfg.dataset.repo_id.startswith("rollout_"):
raise ValueError(
"Dataset names for rollout must start with 'rollout_'. "
"Use --dataset.repo_id=rollout_<name> for policy deployment datasets."
)
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
cfg.dataset.repo_id, cfg.dataset.repo_id,
cfg.dataset.fps, cfg.dataset.fps,
+3 -1
View File
@@ -267,7 +267,9 @@ def send_next_action(
if interp is None: if interp is None:
return None return None
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} if len(interp) != len(ordered_keys):
raise ValueError(f"Interpolated tensor length ({len(interp)}) != action keys ({len(ordered_keys)})")
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys)}
processed = ctx.processors.robot_action_processor((action_dict, obs_raw)) processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
ctx.hardware.robot_wrapper.send_action(processed) ctx.hardware.robot_wrapper.send_action(processed)
return action_dict return action_dict
+1
View File
@@ -693,6 +693,7 @@ class DAggerStrategy(RolloutStrategy):
} }
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction. # user is responsible for moving the teleop to the same position as the robot when starting the correction.
# Consider also a method that moves the robot to the teleop smoothly (similar to what we do at HW shutdown).
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) # _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
elif new_phase == DAggerPhase.CORRECTING: elif new_phase == DAggerPhase.CORRECTING:
+12 -9
View File
@@ -22,7 +22,7 @@ import os
import sys import sys
import time import time
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent from threading import Event as ThreadingEvent, Lock
from lerobot.common.control_utils import is_headless from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager from lerobot.datasets import VideoEncodingManager
@@ -80,6 +80,7 @@ class HighlightStrategy(RolloutStrategy):
self._push_requested = ThreadingEvent() self._push_requested = ThreadingEvent()
self._push_executor: ThreadPoolExecutor | None = None self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None self._pending_push: Future | None = None
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None: def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine, ring buffer, and keyboard listener.""" """Initialise the inference engine, ring buffer, and keyboard listener."""
@@ -168,7 +169,8 @@ class HighlightStrategy(RolloutStrategy):
self._recording_live.set() self._recording_live.set()
else: else:
dataset.add_frame(frame) dataset.add_frame(frame)
dataset.save_episode() with self._episode_lock:
dataset.save_episode()
logger.info("Episode saved (total: %d)", dataset.num_episodes) logger.info("Episode saved (total: %d)", dataset.num_episodes)
log_say( log_say(
f"Episode {dataset.num_episodes} saved", f"Episode {dataset.num_episodes} saved",
@@ -198,7 +200,7 @@ class HighlightStrategy(RolloutStrategy):
logger.info("Highlight control loop ended") logger.info("Highlight control loop ended")
if self._recording_live.is_set(): if self._recording_live.is_set():
logger.info("Saving in-progress live episode") logger.info("Saving in-progress live episode")
with contextlib.suppress(Exception): with contextlib.suppress(Exception), self._episode_lock:
dataset.save_episode() dataset.save_episode()
def teardown(self, ctx: RolloutContext) -> None: def teardown(self, ctx: RolloutContext) -> None:
@@ -268,12 +270,13 @@ class HighlightStrategy(RolloutStrategy):
def _push(): def _push():
try: try:
if safe_push_to_hub( with self._episode_lock:
dataset, if safe_push_to_hub(
tags=cfg.dataset.tags if cfg.dataset else None, dataset,
private=cfg.dataset.private if cfg.dataset else False, tags=cfg.dataset.tags if cfg.dataset else None,
): private=cfg.dataset.private if cfg.dataset else False,
logger.info("Background push to hub complete") ):
logger.info("Background push to hub complete")
except Exception as e: except Exception as e:
logger.error("Background push failed: %s", e) logger.error("Background push failed: %s", e)
+1
View File
@@ -394,6 +394,7 @@ def record(
"Dataset names starting with 'eval_' are reserved for policy evaluation. " "Dataset names starting with 'eval_' are reserved for policy evaluation. "
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment." "lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
) )
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
cfg.dataset.repo_id, cfg.dataset.repo_id,
cfg.dataset.fps, cfg.dataset.fps,
+18 -11
View File
@@ -86,7 +86,7 @@ def test_dagger_config_defaults():
def test_inference_config_types(): def test_inference_config_types():
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig
assert SyncInferenceConfig().type == "sync" assert SyncInferenceConfig().type == "sync"
@@ -110,7 +110,7 @@ def test_sentry_config_defaults():
def test_ring_buffer_append_and_eviction(): def test_ring_buffer_append_and_eviction():
from lerobot.rollout import RolloutRingBuffer from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5 # max_frames = 5
@@ -120,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
def test_ring_buffer_drain(): def test_ring_buffer_drain():
from lerobot.rollout import RolloutRingBuffer from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3): for i in range(3):
@@ -132,7 +132,7 @@ def test_ring_buffer_drain():
def test_ring_buffer_clear(): def test_ring_buffer_clear():
from lerobot.rollout import RolloutRingBuffer from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1}) buf.append({"val": 1})
@@ -142,7 +142,7 @@ def test_ring_buffer_clear():
def test_ring_buffer_tensor_bytes(): def test_ring_buffer_tensor_bytes():
from lerobot.rollout import RolloutRingBuffer from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes t = torch.zeros(100, dtype=torch.float32) # 400 bytes
@@ -156,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
def test_thread_safe_robot_delegates(): def test_thread_safe_robot_delegates():
from lerobot.rollout import ThreadSafeRobot from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3)) robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -176,7 +176,7 @@ def test_thread_safe_robot_delegates():
def test_thread_safe_robot_properties(): def test_thread_safe_robot_properties():
from lerobot.rollout import ThreadSafeRobot from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3)) robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -198,8 +198,15 @@ def test_thread_safe_robot_properties():
def test_create_strategy_dispatches(): def test_create_strategy_dispatches():
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig from lerobot.rollout import (
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy BaseStrategy,
BaseStrategyConfig,
DAggerStrategy,
DAggerStrategyConfig,
SentryStrategy,
SentryStrategyConfig,
create_strategy,
)
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy) assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy) assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
@@ -207,7 +214,7 @@ def test_create_strategy_dispatches():
def test_create_strategy_unknown_raises(): def test_create_strategy_unknown_raises():
from lerobot.rollout.strategies import create_strategy from lerobot.rollout import create_strategy
cfg = MagicMock() cfg = MagicMock()
cfg.type = "bogus" cfg.type = "bogus"
@@ -221,7 +228,7 @@ def test_create_strategy_unknown_raises():
def test_create_inference_engine_sync(): def test_create_inference_engine_sync():
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
engine = create_inference_engine( engine = create_inference_engine(
SyncInferenceConfig(), SyncInferenceConfig(),