chore(rollout): multiple minor improvements

This commit is contained in:
Steven Palma
2026-04-22 15:18:50 +02:00
parent bff2d50dc1
commit d70c3baf7c
9 changed files with 61 additions and 29 deletions
+6 -1
View File
@@ -69,7 +69,12 @@ class DatasetRecordConfig:
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
def __post_init__(self) -> None:
def stamp_repo_id(self) -> None:
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
Must be called explicitly at dataset *creation* time — not on resume,
where the existing ``repo_id`` (already stamped) must be preserved.
"""
if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}"
+12 -5
View File
@@ -46,17 +46,25 @@ from .inference import (
SyncInferenceEngine,
create_inference_engine,
)
from .ring_buffer import RolloutRingBuffer
from .robot_wrapper import ThreadSafeRobot
from .strategies import RolloutStrategy, create_strategy
from .strategies import (
BaseStrategy,
DAggerStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
create_strategy,
)
__all__ = [
"BaseStrategy",
"BaseStrategyConfig",
"DAggerKeyboardConfig",
"DAggerPedalConfig",
"DAggerStrategy",
"DAggerStrategyConfig",
"DatasetContext",
"HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
@@ -66,14 +74,13 @@ __all__ = [
"RTCInferenceEngine",
"RolloutConfig",
"RolloutContext",
"RolloutRingBuffer",
"RolloutStrategy",
"RolloutStrategyConfig",
"RuntimeContext",
"SentryStrategy",
"SentryStrategyConfig",
"SyncInferenceConfig",
"SyncInferenceEngine",
"ThreadSafeRobot",
"build_rollout_context",
"create_inference_engine",
"create_strategy",
+2 -2
View File
@@ -89,8 +89,8 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
again.
"""
ring_buffer_seconds: float = 30.0
ring_buffer_max_memory_mb: float = 2048.0
ring_buffer_seconds: float = 10.0
ring_buffer_max_memory_mb: float = 1024.0
save_key: str = "s"
push_key: str = "h"
+6
View File
@@ -346,6 +346,12 @@ def build_rollout_context(
"names": None,
}
if not cfg.dataset.repo_id.startswith("rollout_"):
raise ValueError(
"Dataset names for rollout must start with 'rollout_'. "
"Use --dataset.repo_id=rollout_<name> for policy deployment datasets."
)
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
+3 -1
View File
@@ -267,7 +267,9 @@ def send_next_action(
if interp is None:
return None
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
if len(interp) != len(ordered_keys):
raise ValueError(f"Interpolated tensor length ({len(interp)}) != action keys ({len(ordered_keys)})")
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys)}
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
ctx.hardware.robot_wrapper.send_action(processed)
return action_dict
+1
View File
@@ -693,6 +693,7 @@ class DAggerStrategy(RolloutStrategy):
}
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# Consider also a method that moves the robot to the teleop smoothly (similar to what we do at HW shutdown).
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
elif new_phase == DAggerPhase.CORRECTING:
+12 -9
View File
@@ -22,7 +22,7 @@ import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent
from threading import Event as ThreadingEvent, Lock
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
@@ -80,6 +80,7 @@ class HighlightStrategy(RolloutStrategy):
self._push_requested = ThreadingEvent()
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine, ring buffer, and keyboard listener."""
@@ -168,7 +169,8 @@ class HighlightStrategy(RolloutStrategy):
self._recording_live.set()
else:
dataset.add_frame(frame)
dataset.save_episode()
with self._episode_lock:
dataset.save_episode()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
log_say(
f"Episode {dataset.num_episodes} saved",
@@ -198,7 +200,7 @@ class HighlightStrategy(RolloutStrategy):
logger.info("Highlight control loop ended")
if self._recording_live.is_set():
logger.info("Saving in-progress live episode")
with contextlib.suppress(Exception):
with contextlib.suppress(Exception), self._episode_lock:
dataset.save_episode()
def teardown(self, ctx: RolloutContext) -> None:
@@ -268,12 +270,13 @@ class HighlightStrategy(RolloutStrategy):
def _push():
try:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
logger.info("Background push to hub complete")
with self._episode_lock:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
+1
View File
@@ -394,6 +394,7 @@ def record(
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
)
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
+18 -11
View File
@@ -86,7 +86,7 @@ def test_dagger_config_defaults():
def test_inference_config_types():
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig
assert SyncInferenceConfig().type == "sync"
@@ -110,7 +110,7 @@ def test_sentry_config_defaults():
def test_ring_buffer_append_and_eviction():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5
@@ -120,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
def test_ring_buffer_drain():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3):
@@ -132,7 +132,7 @@ def test_ring_buffer_drain():
def test_ring_buffer_clear():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1})
@@ -142,7 +142,7 @@ def test_ring_buffer_clear():
def test_ring_buffer_tensor_bytes():
from lerobot.rollout import RolloutRingBuffer
from lerobot.rollout.ring_buffer import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
@@ -156,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
def test_thread_safe_robot_delegates():
from lerobot.rollout import ThreadSafeRobot
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -176,7 +176,7 @@ def test_thread_safe_robot_delegates():
def test_thread_safe_robot_properties():
from lerobot.rollout import ThreadSafeRobot
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
@@ -198,8 +198,15 @@ def test_thread_safe_robot_properties():
def test_create_strategy_dispatches():
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
from lerobot.rollout import (
BaseStrategy,
BaseStrategyConfig,
DAggerStrategy,
DAggerStrategyConfig,
SentryStrategy,
SentryStrategyConfig,
create_strategy,
)
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
@@ -207,7 +214,7 @@ def test_create_strategy_dispatches():
def test_create_strategy_unknown_raises():
from lerobot.rollout.strategies import create_strategy
from lerobot.rollout import create_strategy
cfg = MagicMock()
cfg.type = "bogus"
@@ -221,7 +228,7 @@ def test_create_strategy_unknown_raises():
def test_create_inference_engine_sync():
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
engine = create_inference_engine(
SyncInferenceConfig(),