mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
chore(rollout): multiple minor improvements
This commit is contained in:
@@ -69,7 +69,12 @@ class DatasetRecordConfig:
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def stamp_repo_id(self) -> None:
|
||||
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
|
||||
|
||||
Must be called explicitly at dataset *creation* time — not on resume,
|
||||
where the existing ``repo_id`` (already stamped) must be preserved.
|
||||
"""
|
||||
if self.repo_id:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.repo_id = f"{self.repo_id}_{timestamp}"
|
||||
|
||||
@@ -46,17 +46,25 @@ from .inference import (
|
||||
SyncInferenceEngine,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .ring_buffer import RolloutRingBuffer
|
||||
from .robot_wrapper import ThreadSafeRobot
|
||||
from .strategies import RolloutStrategy, create_strategy
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
create_strategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseStrategy",
|
||||
"BaseStrategyConfig",
|
||||
"DAggerKeyboardConfig",
|
||||
"DAggerPedalConfig",
|
||||
"DAggerStrategy",
|
||||
"DAggerStrategyConfig",
|
||||
"DatasetContext",
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
@@ -66,14 +74,13 @@ __all__ = [
|
||||
"RTCInferenceEngine",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"RolloutRingBuffer",
|
||||
"RolloutStrategy",
|
||||
"RolloutStrategyConfig",
|
||||
"RuntimeContext",
|
||||
"SentryStrategy",
|
||||
"SentryStrategyConfig",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
"ThreadSafeRobot",
|
||||
"build_rollout_context",
|
||||
"create_inference_engine",
|
||||
"create_strategy",
|
||||
|
||||
@@ -89,8 +89,8 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
|
||||
again.
|
||||
"""
|
||||
|
||||
ring_buffer_seconds: float = 30.0
|
||||
ring_buffer_max_memory_mb: float = 2048.0
|
||||
ring_buffer_seconds: float = 10.0
|
||||
ring_buffer_max_memory_mb: float = 1024.0
|
||||
save_key: str = "s"
|
||||
push_key: str = "h"
|
||||
|
||||
|
||||
@@ -346,6 +346,12 @@ def build_rollout_context(
|
||||
"names": None,
|
||||
}
|
||||
|
||||
if not cfg.dataset.repo_id.startswith("rollout_"):
|
||||
raise ValueError(
|
||||
"Dataset names for rollout must start with 'rollout_'. "
|
||||
"Use --dataset.repo_id=rollout_<name> for policy deployment datasets."
|
||||
)
|
||||
cfg.dataset.stamp_repo_id()
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
|
||||
@@ -267,7 +267,9 @@ def send_next_action(
|
||||
if interp is None:
|
||||
return None
|
||||
|
||||
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
|
||||
if len(interp) != len(ordered_keys):
|
||||
raise ValueError(f"Interpolated tensor length ({len(interp)}) != action keys ({len(ordered_keys)})")
|
||||
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys)}
|
||||
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
|
||||
ctx.hardware.robot_wrapper.send_action(processed)
|
||||
return action_dict
|
||||
|
||||
@@ -693,6 +693,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
}
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# Consider also a method that moves the robot to the teleop smoothly (similar to what we do at HW shutdown).
|
||||
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
elif new_phase == DAggerPhase.CORRECTING:
|
||||
|
||||
@@ -22,7 +22,7 @@ import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event as ThreadingEvent
|
||||
from threading import Event as ThreadingEvent, Lock
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
@@ -80,6 +80,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
self._push_requested = ThreadingEvent()
|
||||
self._push_executor: ThreadPoolExecutor | None = None
|
||||
self._pending_push: Future | None = None
|
||||
self._episode_lock = Lock()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Initialise the inference engine, ring buffer, and keyboard listener."""
|
||||
@@ -168,7 +169,8 @@ class HighlightStrategy(RolloutStrategy):
|
||||
self._recording_live.set()
|
||||
else:
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
||||
log_say(
|
||||
f"Episode {dataset.num_episodes} saved",
|
||||
@@ -198,7 +200,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
logger.info("Highlight control loop ended")
|
||||
if self._recording_live.is_set():
|
||||
logger.info("Saving in-progress live episode")
|
||||
with contextlib.suppress(Exception):
|
||||
with contextlib.suppress(Exception), self._episode_lock:
|
||||
dataset.save_episode()
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
@@ -268,12 +270,13 @@ class HighlightStrategy(RolloutStrategy):
|
||||
|
||||
def _push():
|
||||
try:
|
||||
if safe_push_to_hub(
|
||||
dataset,
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
):
|
||||
logger.info("Background push to hub complete")
|
||||
with self._episode_lock:
|
||||
if safe_push_to_hub(
|
||||
dataset,
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
):
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
logger.error("Background push failed: %s", e)
|
||||
|
||||
|
||||
@@ -394,6 +394,7 @@ def record(
|
||||
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
|
||||
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
|
||||
)
|
||||
cfg.dataset.stamp_repo_id()
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
|
||||
Reference in New Issue
Block a user