mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
chore(rollout): multiple minor improvements
This commit is contained in:
@@ -69,7 +69,12 @@ class DatasetRecordConfig:
|
|||||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||||
encoder_threads: int | None = None
|
encoder_threads: int | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def stamp_repo_id(self) -> None:
|
||||||
|
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
|
||||||
|
|
||||||
|
Must be called explicitly at dataset *creation* time — not on resume,
|
||||||
|
where the existing ``repo_id`` (already stamped) must be preserved.
|
||||||
|
"""
|
||||||
if self.repo_id:
|
if self.repo_id:
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
self.repo_id = f"{self.repo_id}_{timestamp}"
|
self.repo_id = f"{self.repo_id}_{timestamp}"
|
||||||
|
|||||||
@@ -46,17 +46,25 @@ from .inference import (
|
|||||||
SyncInferenceEngine,
|
SyncInferenceEngine,
|
||||||
create_inference_engine,
|
create_inference_engine,
|
||||||
)
|
)
|
||||||
from .ring_buffer import RolloutRingBuffer
|
from .strategies import (
|
||||||
from .robot_wrapper import ThreadSafeRobot
|
BaseStrategy,
|
||||||
from .strategies import RolloutStrategy, create_strategy
|
DAggerStrategy,
|
||||||
|
HighlightStrategy,
|
||||||
|
RolloutStrategy,
|
||||||
|
SentryStrategy,
|
||||||
|
create_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"BaseStrategy",
|
||||||
"BaseStrategyConfig",
|
"BaseStrategyConfig",
|
||||||
"DAggerKeyboardConfig",
|
"DAggerKeyboardConfig",
|
||||||
"DAggerPedalConfig",
|
"DAggerPedalConfig",
|
||||||
|
"DAggerStrategy",
|
||||||
"DAggerStrategyConfig",
|
"DAggerStrategyConfig",
|
||||||
"DatasetContext",
|
"DatasetContext",
|
||||||
"HardwareContext",
|
"HardwareContext",
|
||||||
|
"HighlightStrategy",
|
||||||
"HighlightStrategyConfig",
|
"HighlightStrategyConfig",
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"InferenceEngineConfig",
|
"InferenceEngineConfig",
|
||||||
@@ -66,14 +74,13 @@ __all__ = [
|
|||||||
"RTCInferenceEngine",
|
"RTCInferenceEngine",
|
||||||
"RolloutConfig",
|
"RolloutConfig",
|
||||||
"RolloutContext",
|
"RolloutContext",
|
||||||
"RolloutRingBuffer",
|
|
||||||
"RolloutStrategy",
|
"RolloutStrategy",
|
||||||
"RolloutStrategyConfig",
|
"RolloutStrategyConfig",
|
||||||
"RuntimeContext",
|
"RuntimeContext",
|
||||||
|
"SentryStrategy",
|
||||||
"SentryStrategyConfig",
|
"SentryStrategyConfig",
|
||||||
"SyncInferenceConfig",
|
"SyncInferenceConfig",
|
||||||
"SyncInferenceEngine",
|
"SyncInferenceEngine",
|
||||||
"ThreadSafeRobot",
|
|
||||||
"build_rollout_context",
|
"build_rollout_context",
|
||||||
"create_inference_engine",
|
"create_inference_engine",
|
||||||
"create_strategy",
|
"create_strategy",
|
||||||
|
|||||||
@@ -89,8 +89,8 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
|
|||||||
again.
|
again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ring_buffer_seconds: float = 30.0
|
ring_buffer_seconds: float = 10.0
|
||||||
ring_buffer_max_memory_mb: float = 2048.0
|
ring_buffer_max_memory_mb: float = 1024.0
|
||||||
save_key: str = "s"
|
save_key: str = "s"
|
||||||
push_key: str = "h"
|
push_key: str = "h"
|
||||||
|
|
||||||
|
|||||||
@@ -346,6 +346,12 @@ def build_rollout_context(
|
|||||||
"names": None,
|
"names": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not cfg.dataset.repo_id.startswith("rollout_"):
|
||||||
|
raise ValueError(
|
||||||
|
"Dataset names for rollout must start with 'rollout_'. "
|
||||||
|
"Use --dataset.repo_id=rollout_<name> for policy deployment datasets."
|
||||||
|
)
|
||||||
|
cfg.dataset.stamp_repo_id()
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
cfg.dataset.repo_id,
|
cfg.dataset.repo_id,
|
||||||
cfg.dataset.fps,
|
cfg.dataset.fps,
|
||||||
|
|||||||
@@ -267,7 +267,9 @@ def send_next_action(
|
|||||||
if interp is None:
|
if interp is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
|
if len(interp) != len(ordered_keys):
|
||||||
|
raise ValueError(f"Interpolated tensor length ({len(interp)}) != action keys ({len(ordered_keys)})")
|
||||||
|
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys)}
|
||||||
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
|
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
|
||||||
ctx.hardware.robot_wrapper.send_action(processed)
|
ctx.hardware.robot_wrapper.send_action(processed)
|
||||||
return action_dict
|
return action_dict
|
||||||
|
|||||||
@@ -693,6 +693,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
}
|
}
|
||||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||||
|
# Consider also a method that moves the robot to the teleop smoothly (similar to what we do at HW shutdown).
|
||||||
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||||
|
|
||||||
elif new_phase == DAggerPhase.CORRECTING:
|
elif new_phase == DAggerPhase.CORRECTING:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from threading import Event as ThreadingEvent
|
from threading import Event as ThreadingEvent, Lock
|
||||||
|
|
||||||
from lerobot.common.control_utils import is_headless
|
from lerobot.common.control_utils import is_headless
|
||||||
from lerobot.datasets import VideoEncodingManager
|
from lerobot.datasets import VideoEncodingManager
|
||||||
@@ -80,6 +80,7 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
self._push_requested = ThreadingEvent()
|
self._push_requested = ThreadingEvent()
|
||||||
self._push_executor: ThreadPoolExecutor | None = None
|
self._push_executor: ThreadPoolExecutor | None = None
|
||||||
self._pending_push: Future | None = None
|
self._pending_push: Future | None = None
|
||||||
|
self._episode_lock = Lock()
|
||||||
|
|
||||||
def setup(self, ctx: RolloutContext) -> None:
|
def setup(self, ctx: RolloutContext) -> None:
|
||||||
"""Initialise the inference engine, ring buffer, and keyboard listener."""
|
"""Initialise the inference engine, ring buffer, and keyboard listener."""
|
||||||
@@ -168,7 +169,8 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
self._recording_live.set()
|
self._recording_live.set()
|
||||||
else:
|
else:
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
dataset.save_episode()
|
with self._episode_lock:
|
||||||
|
dataset.save_episode()
|
||||||
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
||||||
log_say(
|
log_say(
|
||||||
f"Episode {dataset.num_episodes} saved",
|
f"Episode {dataset.num_episodes} saved",
|
||||||
@@ -198,7 +200,7 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
logger.info("Highlight control loop ended")
|
logger.info("Highlight control loop ended")
|
||||||
if self._recording_live.is_set():
|
if self._recording_live.is_set():
|
||||||
logger.info("Saving in-progress live episode")
|
logger.info("Saving in-progress live episode")
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception), self._episode_lock:
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
def teardown(self, ctx: RolloutContext) -> None:
|
def teardown(self, ctx: RolloutContext) -> None:
|
||||||
@@ -268,12 +270,13 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
|
|
||||||
def _push():
|
def _push():
|
||||||
try:
|
try:
|
||||||
if safe_push_to_hub(
|
with self._episode_lock:
|
||||||
dataset,
|
if safe_push_to_hub(
|
||||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
dataset,
|
||||||
private=cfg.dataset.private if cfg.dataset else False,
|
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||||
):
|
private=cfg.dataset.private if cfg.dataset else False,
|
||||||
logger.info("Background push to hub complete")
|
):
|
||||||
|
logger.info("Background push to hub complete")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Background push failed: %s", e)
|
logger.error("Background push failed: %s", e)
|
||||||
|
|
||||||
|
|||||||
@@ -394,6 +394,7 @@ def record(
|
|||||||
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
|
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
|
||||||
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
|
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
|
||||||
)
|
)
|
||||||
|
cfg.dataset.stamp_repo_id()
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
cfg.dataset.repo_id,
|
cfg.dataset.repo_id,
|
||||||
cfg.dataset.fps,
|
cfg.dataset.fps,
|
||||||
|
|||||||
+18
-11
@@ -86,7 +86,7 @@ def test_dagger_config_defaults():
|
|||||||
|
|
||||||
|
|
||||||
def test_inference_config_types():
|
def test_inference_config_types():
|
||||||
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
|
from lerobot.rollout import RTCInferenceConfig, SyncInferenceConfig
|
||||||
|
|
||||||
assert SyncInferenceConfig().type == "sync"
|
assert SyncInferenceConfig().type == "sync"
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ def test_sentry_config_defaults():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_append_and_eviction():
|
def test_ring_buffer_append_and_eviction():
|
||||||
from lerobot.rollout import RolloutRingBuffer
|
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
||||||
# max_frames = 5
|
# max_frames = 5
|
||||||
@@ -120,7 +120,7 @@ def test_ring_buffer_append_and_eviction():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_drain():
|
def test_ring_buffer_drain():
|
||||||
from lerobot.rollout import RolloutRingBuffer
|
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
@@ -132,7 +132,7 @@ def test_ring_buffer_drain():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_clear():
|
def test_ring_buffer_clear():
|
||||||
from lerobot.rollout import RolloutRingBuffer
|
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||||
buf.append({"val": 1})
|
buf.append({"val": 1})
|
||||||
@@ -142,7 +142,7 @@ def test_ring_buffer_clear():
|
|||||||
|
|
||||||
|
|
||||||
def test_ring_buffer_tensor_bytes():
|
def test_ring_buffer_tensor_bytes():
|
||||||
from lerobot.rollout import RolloutRingBuffer
|
from lerobot.rollout.ring_buffer import RolloutRingBuffer
|
||||||
|
|
||||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||||
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||||
@@ -156,7 +156,7 @@ def test_ring_buffer_tensor_bytes():
|
|||||||
|
|
||||||
|
|
||||||
def test_thread_safe_robot_delegates():
|
def test_thread_safe_robot_delegates():
|
||||||
from lerobot.rollout import ThreadSafeRobot
|
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||||
|
|
||||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||||
@@ -176,7 +176,7 @@ def test_thread_safe_robot_delegates():
|
|||||||
|
|
||||||
|
|
||||||
def test_thread_safe_robot_properties():
|
def test_thread_safe_robot_properties():
|
||||||
from lerobot.rollout import ThreadSafeRobot
|
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||||
|
|
||||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||||
@@ -198,8 +198,15 @@ def test_thread_safe_robot_properties():
|
|||||||
|
|
||||||
|
|
||||||
def test_create_strategy_dispatches():
|
def test_create_strategy_dispatches():
|
||||||
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
from lerobot.rollout import (
|
||||||
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
|
BaseStrategy,
|
||||||
|
BaseStrategyConfig,
|
||||||
|
DAggerStrategy,
|
||||||
|
DAggerStrategyConfig,
|
||||||
|
SentryStrategy,
|
||||||
|
SentryStrategyConfig,
|
||||||
|
create_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||||
@@ -207,7 +214,7 @@ def test_create_strategy_dispatches():
|
|||||||
|
|
||||||
|
|
||||||
def test_create_strategy_unknown_raises():
|
def test_create_strategy_unknown_raises():
|
||||||
from lerobot.rollout.strategies import create_strategy
|
from lerobot.rollout import create_strategy
|
||||||
|
|
||||||
cfg = MagicMock()
|
cfg = MagicMock()
|
||||||
cfg.type = "bogus"
|
cfg.type = "bogus"
|
||||||
@@ -221,7 +228,7 @@ def test_create_strategy_unknown_raises():
|
|||||||
|
|
||||||
|
|
||||||
def test_create_inference_engine_sync():
|
def test_create_inference_engine_sync():
|
||||||
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
|
from lerobot.rollout import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
|
||||||
|
|
||||||
engine = create_inference_engine(
|
engine = create_inference_engine(
|
||||||
SyncInferenceConfig(),
|
SyncInferenceConfig(),
|
||||||
|
|||||||
Reference in New Issue
Block a user