chore(rollout): multiple minor improvements

This commit is contained in:
Steven Palma
2026-04-22 15:18:50 +02:00
parent bff2d50dc1
commit d70c3baf7c
9 changed files with 61 additions and 29 deletions
+6 -1
View File
@@ -69,7 +69,12 @@ class DatasetRecordConfig:
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
def __post_init__(self) -> None:
def stamp_repo_id(self) -> None:
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
Must be called explicitly at dataset *creation* time — not on resume,
where the existing ``repo_id`` (already stamped) must be preserved.
"""
if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}"
+12 -5
View File
@@ -46,17 +46,25 @@ from .inference import (
SyncInferenceEngine,
create_inference_engine,
)
from .ring_buffer import RolloutRingBuffer
from .robot_wrapper import ThreadSafeRobot
from .strategies import RolloutStrategy, create_strategy
from .strategies import (
BaseStrategy,
DAggerStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
create_strategy,
)
__all__ = [
"BaseStrategy",
"BaseStrategyConfig",
"DAggerKeyboardConfig",
"DAggerPedalConfig",
"DAggerStrategy",
"DAggerStrategyConfig",
"DatasetContext",
"HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
@@ -66,14 +74,13 @@ __all__ = [
"RTCInferenceEngine",
"RolloutConfig",
"RolloutContext",
"RolloutRingBuffer",
"RolloutStrategy",
"RolloutStrategyConfig",
"RuntimeContext",
"SentryStrategy",
"SentryStrategyConfig",
"SyncInferenceConfig",
"SyncInferenceEngine",
"ThreadSafeRobot",
"build_rollout_context",
"create_inference_engine",
"create_strategy",
+2 -2
View File
@@ -89,8 +89,8 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
again.
"""
ring_buffer_seconds: float = 30.0
ring_buffer_max_memory_mb: float = 2048.0
ring_buffer_seconds: float = 10.0
ring_buffer_max_memory_mb: float = 1024.0
save_key: str = "s"
push_key: str = "h"
+6
View File
@@ -346,6 +346,12 @@ def build_rollout_context(
"names": None,
}
if not cfg.dataset.repo_id.startswith("rollout_"):
raise ValueError(
"Dataset names for rollout must start with 'rollout_'. "
"Use --dataset.repo_id=rollout_<name> for policy deployment datasets."
)
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
+3 -1
View File
@@ -267,7 +267,9 @@ def send_next_action(
if interp is None:
return None
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
if len(interp) != len(ordered_keys):
raise ValueError(f"Interpolated tensor length ({len(interp)}) != action keys ({len(ordered_keys)})")
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys)}
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
ctx.hardware.robot_wrapper.send_action(processed)
return action_dict
+1
View File
@@ -693,6 +693,7 @@ class DAggerStrategy(RolloutStrategy):
}
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# Consider also a method that moves the robot to the teleop smoothly (similar to what we do at HW shutdown).
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
elif new_phase == DAggerPhase.CORRECTING:
+12 -9
View File
@@ -22,7 +22,7 @@ import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent
from threading import Event as ThreadingEvent, Lock
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
@@ -80,6 +80,7 @@ class HighlightStrategy(RolloutStrategy):
self._push_requested = ThreadingEvent()
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine, ring buffer, and keyboard listener."""
@@ -168,7 +169,8 @@ class HighlightStrategy(RolloutStrategy):
self._recording_live.set()
else:
dataset.add_frame(frame)
dataset.save_episode()
with self._episode_lock:
dataset.save_episode()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
log_say(
f"Episode {dataset.num_episodes} saved",
@@ -198,7 +200,7 @@ class HighlightStrategy(RolloutStrategy):
logger.info("Highlight control loop ended")
if self._recording_live.is_set():
logger.info("Saving in-progress live episode")
with contextlib.suppress(Exception):
with contextlib.suppress(Exception), self._episode_lock:
dataset.save_episode()
def teardown(self, ctx: RolloutContext) -> None:
@@ -268,12 +270,13 @@ class HighlightStrategy(RolloutStrategy):
def _push():
try:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
logger.info("Background push to hub complete")
with self._episode_lock:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
+1
View File
@@ -394,6 +394,7 @@ def record(
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
)
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,