diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 62b256e42..e859123d0 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -222,7 +222,6 @@ def main(): # Save episode dataset.save_episode() - episode_idx += 1 finally: # Clean up log_say("Stop recording") diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index c11b9556e..63def68d0 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -222,7 +222,6 @@ def main(): # Save episode dataset.save_episode() - episode_idx += 1 finally: # Clean up log_say("Stop recording") diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index dce765a88..48c60b7fd 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -42,12 +42,32 @@ from lerobot.robots import Robot, make_robot_from_config from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features -from .configs import BaseStrategyConfig, RolloutConfig +from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig from .robot_wrapper import ThreadSafeRobot logger = logging.getLogger(__name__) +def _resolve_action_key_order( + policy_action_names: list[str] | None, dataset_action_names: list[str] +) -> list[str]: + """Choose action name ordering for mapping policy tensor outputs to robot action dicts.""" + if not policy_action_names: + return dataset_action_names + policy_action_names = list(policy_action_names) + if len(policy_action_names) != len(dataset_action_names): + logger.warning( + "policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order", + len(policy_action_names), + len(dataset_action_names), + ) + return dataset_action_names + if set(dataset_action_names) != set(policy_action_names): + logger.warning("policy.action_feature_names keys don't match dataset; using dataset order") + return dataset_action_names + return policy_action_names + + @dataclass class RolloutContext: """Bundle of shared resources passed to every rollout strategy. @@ -69,6 +89,7 @@ class RolloutContext: shutdown_event: Event = field(default_factory=Event) dataset_features: dict = field(default_factory=dict) action_keys: list[str] = field(default_factory=list) + ordered_action_keys: list[str] = field(default_factory=list) hw_features: dict = field(default_factory=dict) @@ -92,26 +113,37 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() # --- Policy --- + # Use cfg.policy directly (already loaded in RolloutConfig.__post_init__) + # instead of reloading from disk. + policy_config = cfg.policy use_rtc = cfg.rtc.enabled - policy_class = get_policy_class(cfg.policy.type) - policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) + policy_class = get_policy_class(policy_config.type) + + # Reload config from pretrained path for full model parameters + full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) + # Merge any CLI overrides from cfg.policy into full_config + for attr in ("device", "use_amp"): + if hasattr(cfg.policy, attr) and hasattr(full_config, attr): + cli_val = getattr(cfg.policy, attr) + if cli_val is not None: + setattr(full_config, attr, cli_val) # Set compile_model for pi0/pi05 - if hasattr(policy_config, "compile_model"): - policy_config.compile_model = cfg.use_torch_compile + if hasattr(full_config, "compile_model"): + full_config.compile_model = cfg.use_torch_compile # Handle PEFT models - if policy_config.use_peft: + if full_config.use_peft: from peft import PeftConfig, PeftModel peft_path = cfg.policy.pretrained_path peft_config = PeftConfig.from_pretrained(peft_path) policy = policy_class.from_pretrained( - pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config + pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config ) policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config) else: - policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config) + policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config) # Enable RTC on the policy if use_rtc: @@ -136,10 +168,13 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC except Exception as e: logger.warning("Failed to apply torch.compile: %s", e) - # --- Observation features (filter to .pos joints + camera streams) --- + # --- Observation features --- + # Hardware-level features: camera features are tuples (H, W, C), state + # features are the ``float`` type. This is the canonical pattern used + # throughout the codebase (see feature_utils.py:hw_to_dataset_features). all_obs_features = robot.observation_features observation_features_hw = { - k: v for k, v in all_obs_features.items() if k.endswith(".pos") or isinstance(v, tuple) + k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple) } action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")} @@ -163,6 +198,13 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC # Action keys action_keys = [k for k in robot.action_features if k.endswith(".pos")] + # Ordered action keys (reconcile policy vs dataset ordering) + policy_action_names = getattr(policy_config, "action_feature_names", None) + ordered_action_keys = _resolve_action_key_order( + list(policy_action_names) if policy_action_names else None, + action_keys, + ) + # --- Dataset --- dataset = None if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig): @@ -180,6 +222,14 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC * len(robot.cameras if hasattr(robot, "cameras") else []), ) else: + # Add intervention column for DAgger strategy + if isinstance(cfg.strategy, DAggerStrategyConfig): + dataset_features["intervention"] = { + "dtype": "int64", + "shape": (1,), + "names": None, + } + dataset = LeRobotDataset.create( cfg.dataset.repo_id, cfg.dataset.fps, @@ -206,11 +256,11 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC ) preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, + policy_cfg=policy_config, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset_stats, preprocessor_overrides={ - "device_processor": {"device": cfg.device or cfg.policy.device}, + "device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")}, "rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}}, }, ) @@ -230,5 +280,6 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC shutdown_event=shutdown_event, dataset_features=dataset_features, action_keys=action_keys, + ordered_action_keys=ordered_action_keys, hw_features=hw_features, ) diff --git a/src/lerobot/rollout/inference.py b/src/lerobot/rollout/inference.py index 2eeac10e9..c5e94bf3f 100644 --- a/src/lerobot/rollout/inference.py +++ b/src/lerobot/rollout/inference.py @@ -31,10 +31,10 @@ from typing import Any import torch -from lerobot.common.control_utils import prepare_observation_for_inference from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker +from lerobot.policies.rtc import ActionQueue, LatencyTracker from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.utils import prepare_observation_for_inference from lerobot.processor import ( NormalizerProcessorStep, PolicyProcessorPipeline, @@ -93,26 +93,6 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int return padded -def _resolve_action_key_order( - policy_action_names: list[str] | None, dataset_action_names: list[str] -) -> list[str]: - """Choose action name ordering for mapping policy tensor outputs to robot action dicts.""" - if not policy_action_names: - return dataset_action_names - policy_action_names = list(policy_action_names) - if len(policy_action_names) != len(dataset_action_names): - logger.warning( - "policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order", - len(policy_action_names), - len(dataset_action_names), - ) - return dataset_action_names - if set(dataset_action_names) != set(policy_action_names): - logger.warning("policy.action_feature_names keys don't match dataset; using dataset order") - return dataset_action_names - return policy_action_names - - # --------------------------------------------------------------------------- # InferenceEngine # --------------------------------------------------------------------------- @@ -143,12 +123,13 @@ class InferenceEngine: Control loop frequency. device: Torch device string. - interpolator: - Action interpolator (used only in RTC mode for the actor loop). use_torch_compile: Whether torch.compile warmup is needed. compile_warmup_inferences: Number of warmup inferences before live rollout. + rtc_queue_threshold: + Maximum RTC action queue size before the background thread + pauses generation. Prevents unbounded queue growth. """ def __init__( @@ -163,9 +144,9 @@ class InferenceEngine: task: str, fps: float, device: str | None, - interpolator: ActionInterpolator | None = None, use_torch_compile: bool = False, compile_warmup_inferences: int = 2, + rtc_queue_threshold: int = 30, ) -> None: self._policy = policy self._preprocessor = preprocessor @@ -177,9 +158,9 @@ class InferenceEngine: self._task = task self._fps = fps self._device = device or "cpu" - self._interpolator = interpolator self._use_torch_compile = use_torch_compile self._compile_warmup_inferences = compile_warmup_inferences + self._rtc_queue_threshold = rtc_queue_threshold # RTC state self._use_rtc = rtc_config.enabled @@ -270,8 +251,6 @@ class InferenceEngine: self._postprocessor.reset() if self._use_rtc: self._action_queue = ActionQueue(self._rtc_config) - if self._interpolator is not None: - self._interpolator.reset() # ------------------------------------------------------------------ # Sync inference @@ -329,8 +308,7 @@ class InferenceEngine: try: latency_tracker = LatencyTracker() time_per_chunk = 1.0 / self._fps - threshold = 30 - policy_device = self._policy.config.device + policy_device = torch.device(self._device) warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0 inference_count = 0 @@ -347,7 +325,7 @@ class InferenceEngine: time.sleep(0.01) continue - if queue.qsize() <= threshold: + if queue.qsize() <= self._rtc_queue_threshold: try: current_time = time.perf_counter() idx_before = queue.get_action_index() @@ -356,35 +334,29 @@ class InferenceEngine: latency = latency_tracker.max() delay = math.ceil(latency / time_per_chunk) if latency else 0 - # Build observation batch + # Build observation batch using the same pipeline as sync inference obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation") - for name in obs_batch: - obs_batch[name] = torch.from_numpy(obs_batch[name]) - if "image" in name: - obs_batch[name] = obs_batch[name].float() / 255 - obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous() - obs_batch[name] = obs_batch[name].unsqueeze(0).to(policy_device) - + obs_batch = prepare_observation_for_inference( + obs_batch, policy_device, self._task, self._robot.robot_type + ) + # predict_action_chunk expects batched task format obs_batch["task"] = [self._task] - obs_batch["robot_type"] = self._obs_holder.get("robot_type", "unknown") preprocessed = self._preprocessor(obs_batch) # Re-anchor leftover for relative-action policies - if ( - prev_actions is not None - and self._relative_step is not None - and OBS_STATE in obs_batch - ): - prev_abs = queue.get_processed_left_over() - if prev_abs is not None and prev_abs.numel() > 0: - prev_actions = _reanchor_relative_rtc_prefix( - prev_actions_absolute=prev_abs, - current_state=obs_batch[OBS_STATE], - relative_step=self._relative_step, - normalizer_step=self._normalizer_step, - policy_device=policy_device, - ) + if prev_actions is not None and self._relative_step is not None: + state_tensor = preprocessed.get(OBS_STATE) + if state_tensor is not None: + prev_abs = queue.get_processed_left_over() + if prev_abs is not None and prev_abs.numel() > 0: + prev_actions = _reanchor_relative_rtc_prefix( + prev_actions_absolute=prev_abs, + current_state=state_tensor, + relative_step=self._relative_step, + normalizer_step=self._normalizer_step, + policy_device=policy_device, + ) if prev_actions is not None: prev_actions = _normalize_prev_actions_length( diff --git a/src/lerobot/rollout/strategies/__init__.py b/src/lerobot/rollout/strategies/__init__.py index bdb3b5952..3c5eee83a 100644 --- a/src/lerobot/rollout/strategies/__init__.py +++ b/src/lerobot/rollout/strategies/__init__.py @@ -12,16 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Rollout strategy ABC and factory.""" +"""Rollout strategy ABC, factory, and shared inference helper.""" from __future__ import annotations import abc from typing import TYPE_CHECKING +import torch + +from lerobot.policies.rtc import ActionInterpolator +from lerobot.policies.utils import make_robot_action +from lerobot.utils.constants import OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame + if TYPE_CHECKING: from lerobot.rollout.configs import RolloutStrategyConfig from lerobot.rollout.context import RolloutContext + from lerobot.rollout.inference import InferenceEngine class RolloutStrategy(abc.ABC): @@ -48,6 +56,77 @@ class RolloutStrategy(abc.ABC): """Cleanup: save dataset, stop threads, disconnect hardware.""" +# --------------------------------------------------------------------------- +# Shared inference helper +# --------------------------------------------------------------------------- + + +def infer_action( + engine: InferenceEngine, + obs_processed: dict, + obs_raw: dict, + ctx: RolloutContext, + interpolator: ActionInterpolator, + ordered_keys: list[str], + features: dict, +) -> dict | None: + """Run one policy inference step and send the resulting action to the robot. + + Handles both sync and RTC backends. Uses the interpolator for smooth + control at higher-than-inference rates (works with any multiplier, + including 1 where it acts as a pass-through). + + Parameters + ---------- + engine: + The inference engine (sync or RTC). + obs_processed: + Observation dict after ``robot_observation_processor``. + obs_raw: + Raw observation dict (needed by ``robot_action_processor``). + ctx: + Rollout context. + interpolator: + Action interpolator for Nx control rate. + ordered_keys: + Ordered action feature names (policy-to-robot mapping). + features: + Feature specification dict for ``build_dataset_frame`` / + ``make_robot_action``. Use ``dataset.features`` when recording, + ``ctx.dataset_features`` otherwise. + + Returns + ------- + Action dict sent to the robot, or ``None`` if no action was + available (empty RTC queue, interpolator buffer not ready). + """ + if engine.is_rtc: + if interpolator.needs_new_action(): + action_tensor = engine.consume_rtc_action() + if action_tensor is not None: + interpolator.add(action_tensor.cpu()) + else: + if interpolator.needs_new_action(): + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_tensor = engine.get_action_sync(obs_frame) + action_dict = make_robot_action(action_tensor, features) + action_t = torch.tensor([action_dict[k] for k in ordered_keys]) + interpolator.add(action_t) + + interp = interpolator.get() + if interp is not None: + action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} + processed = ctx.robot_action_processor((action_dict, obs_raw)) + ctx.robot_wrapper.send_action(processed) + return action_dict + return None + + +# --------------------------------------------------------------------------- +# Strategy factory +# --------------------------------------------------------------------------- + + def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: """Instantiate the appropriate strategy from a config object.""" from lerobot.rollout.configs import ( diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index 5e9699bdb..8f5d5c7d1 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -20,14 +20,11 @@ import logging import time from lerobot.policies.rtc import ActionInterpolator -from lerobot.policies.utils import make_robot_action -from lerobot.utils.constants import OBS_STR -from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep from ..context import RolloutContext -from ..inference import InferenceEngine, _resolve_action_key_order -from . import RolloutStrategy +from ..inference import InferenceEngine +from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) @@ -43,9 +40,10 @@ class BaseStrategy(RolloutStrategy): def __init__(self, config): super().__init__(config) self._engine: InferenceEngine | None = None + self._interpolator: ActionInterpolator | None = None def setup(self, ctx: RolloutContext) -> None: - interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) + self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) self._engine = InferenceEngine( policy=ctx.policy, @@ -58,7 +56,6 @@ class BaseStrategy(RolloutStrategy): task=ctx.cfg.task, fps=ctx.cfg.fps, device=ctx.cfg.device, - interpolator=interpolator, use_torch_compile=ctx.cfg.use_torch_compile, compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, ) @@ -69,16 +66,10 @@ class BaseStrategy(RolloutStrategy): engine = self._engine cfg = ctx.cfg robot = ctx.robot_wrapper - action_keys = ctx.action_keys + interpolator = self._interpolator - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) control_interval = interpolator.get_control_interval(cfg.fps) - - policy_action_names = getattr(cfg.policy, "action_feature_names", None) - ordered_keys = _resolve_action_key_order( - list(policy_action_names) if policy_action_names else None, - action_keys, - ) + ordered_keys = ctx.ordered_action_keys start_time = time.perf_counter() warmup_flushed = False @@ -98,34 +89,21 @@ class BaseStrategy(RolloutStrategy): if engine.is_rtc: engine.update_observation(obs_processed) - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) - continue + # Wait for torch.compile warmup before running live inference + if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True + if cfg.use_torch_compile and not warmup_flushed: + engine.reset() + interpolator.reset() + warmup_flushed = True + if engine.is_rtc: + engine.resume() - if interpolator.needs_new_action(): - action_tensor = engine.consume_rtc_action() - if action_tensor is not None: - interpolator.add(action_tensor.cpu()) - - interp = interpolator.get() - if interp is not None: - action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} - processed = ctx.robot_action_processor((action_dict, obs)) - robot.send_action(processed) - - else: - obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR) - action_tensor = engine.get_action_sync(obs_frame) - action_dict = make_robot_action(action_tensor, ctx.dataset_features) - processed = ctx.robot_action_processor((action_dict, obs)) - robot.send_action(processed) + infer_action(engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.dataset_features) dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index ca1360ceb..a6d82ec3f 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -29,26 +29,26 @@ Keyboard Controls: from __future__ import annotations +import contextlib import logging import time from typing import Any -import torch +import numpy as np -from lerobot.common.control_utils import is_headless, predict_action +from lerobot.common.control_utils import is_headless from lerobot.datasets import VideoEncodingManager from lerobot.policies.rtc import ActionInterpolator -from lerobot.policies.utils import make_robot_action +from lerobot.processor import RobotProcessorPipeline from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say from ..configs import DAggerStrategyConfig from ..context import RolloutContext -from ..inference import InferenceEngine, _resolve_action_key_order -from . import RolloutStrategy +from ..inference import InferenceEngine +from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) @@ -94,8 +94,19 @@ def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fp time.sleep(1 / fps) -def _reset_loop(robot, teleop, events: dict, fps: int) -> None: - """Reset period where the human repositions the environment.""" +def _reset_loop( + robot, + teleop, + events: dict, + fps: int, + teleop_action_processor: RobotProcessorPipeline, + robot_action_processor: RobotProcessorPipeline, +) -> None: + """Reset period where the human repositions the environment. + + All teleop actions flow through the processor pipelines to ensure + correct behavior for EE-space robots. + """ logger.info("RESET — press any key to enable teleoperation") events["in_reset"] = True @@ -117,8 +128,11 @@ def _reset_loop(robot, teleop, events: dict, fps: int) -> None: while not events["start_next_episode"] and not events["stop_recording"]: loop_start = time.perf_counter() + obs = robot.get_observation() action = teleop.get_action() - robot.send_action(action) + processed_teleop = teleop_action_processor((action, obs)) + robot_action_to_send = robot_action_processor((processed_teleop, obs)) + robot.send_action(robot_action_to_send) precise_sleep(1 / fps - (time.perf_counter() - loop_start)) events["in_reset"] = False @@ -251,6 +265,10 @@ class DAggerStrategy(RolloutStrategy): Supports both synchronous and RTC inference backends. All actions (policy and teleop) flow through the appropriate processor pipelines, supporting EE-space recording. + + Intervention frames are tagged with ``intervention=1`` (int64) in + the dataset to allow downstream BC training to distinguish + autonomous from human-corrected data. """ config: DAggerStrategyConfig @@ -258,11 +276,12 @@ class DAggerStrategy(RolloutStrategy): def __init__(self, config: DAggerStrategyConfig): super().__init__(config) self._engine: InferenceEngine | None = None + self._interpolator: ActionInterpolator | None = None self._listener = None self._events: dict[str, Any] = {} def setup(self, ctx: RolloutContext) -> None: - interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) + self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) self._engine = InferenceEngine( policy=ctx.policy, @@ -275,7 +294,6 @@ class DAggerStrategy(RolloutStrategy): task=ctx.cfg.task, fps=ctx.cfg.fps, device=ctx.cfg.device, - interpolator=interpolator, use_torch_compile=ctx.cfg.use_torch_compile, compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, ) @@ -293,7 +311,6 @@ class DAggerStrategy(RolloutStrategy): logger.info("Controls: SPACE=pause, c=take control, p=resume, ->=end, <-=redo, ESC=stop") def run(self, ctx: RolloutContext) -> None: - engine = self._engine dataset = ctx.dataset events = self._events teleop = ctx.teleop @@ -317,13 +334,18 @@ class DAggerStrategy(RolloutStrategy): recorded += 1 if recorded < self.config.num_episodes and not events["stop_recording"]: - _reset_loop(ctx.robot_wrapper, teleop, events, int(ctx.cfg.fps)) + _reset_loop( + ctx.robot_wrapper, + teleop, + events, + int(ctx.cfg.fps), + ctx.teleop_action_processor, + ctx.robot_action_processor, + ) finally: - try: + with contextlib.suppress(Exception): dataset.save_episode() - except Exception: - pass def teardown(self, ctx: RolloutContext) -> None: log_say("Stop recording", self.config.play_sounds, blocking=True) @@ -360,27 +382,22 @@ class DAggerStrategy(RolloutStrategy): teleop = ctx.teleop dataset = ctx.dataset events = self._events + interpolator = self._interpolator - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) control_interval = interpolator.get_control_interval(cfg.fps) stream_online = bool(cfg.dataset.streaming_encoding) if cfg.dataset else False record_stride = max(1, cfg.interpolation_multiplier) - policy_action_names = getattr(cfg.policy, "action_feature_names", None) - ordered_keys = _resolve_action_key_order( - list(policy_action_names) if policy_action_names else None, - ctx.action_keys, - ) - - dataset_action_keys = list(dataset.features.get(ACTION, {}).get("names", ctx.action_keys)) + ordered_keys = ctx.ordered_action_keys + features = dataset.features engine.reset() + interpolator.reset() _teleop_disable_torque(teleop) was_paused = False waiting_for_takeover = False last_action: dict[str, Any] | None = None - robot_action: dict[str, Any] = {} frame_buffer: list[dict] = [] task_str = cfg.dataset.single_task if cfg.dataset else cfg.task @@ -444,7 +461,7 @@ class DAggerStrategy(RolloutStrategy): # --- Get observation --- obs = robot.get_observation() obs_processed = ctx.robot_observation_processor(obs) - obs_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) # --- CORRECTION: human teleop control --- if events["correction_active"]: @@ -452,9 +469,14 @@ class DAggerStrategy(RolloutStrategy): processed_teleop = ctx.teleop_action_processor((teleop_action, obs)) robot_action_to_send = ctx.robot_action_processor((processed_teleop, obs)) robot.send_action(robot_action_to_send) - action_frame = build_dataset_frame(dataset.features, processed_teleop, prefix=ACTION) + action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION) if record_tick % record_stride == 0: - frame = {**obs_frame, **action_frame, "task": task_str} + frame = { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([1], dtype=np.int64), + } if stream_online: dataset.add_frame(frame) else: @@ -471,73 +493,40 @@ class DAggerStrategy(RolloutStrategy): if engine.is_rtc: engine.update_observation(obs_processed) - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) - timestamp = time.perf_counter() - start_t - continue + # Wait for torch.compile warmup + if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + timestamp = time.perf_counter() - start_t + continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True - if engine.is_rtc: - engine.resume() + if cfg.use_torch_compile and not warmup_flushed: + engine.reset() + interpolator.reset() + warmup_flushed = True + if engine.is_rtc: + engine.resume() - if interpolator.needs_new_action(): - action_tensor = engine.consume_rtc_action() - if action_tensor is not None: - interpolator.add(action_tensor.cpu()) + action_dict = infer_action( + engine, obs_processed, obs, ctx, interpolator, ordered_keys, features + ) - interp = interpolator.get() - if interp is not None: - robot_action = { - k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp) + if action_dict is not None: + last_action = ctx.robot_action_processor((action_dict, obs)) + action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) + if record_tick % record_stride == 0: + frame = { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([0], dtype=np.int64), } - processed = ctx.robot_action_processor((robot_action, obs)) - robot.send_action(processed) - last_action = processed - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - if record_tick % record_stride == 0: - frame = {**obs_frame, **action_frame, "task": task_str} - if stream_online: - dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 - else: - # Sync inference - if interpolator.needs_new_action(): - device = get_safe_torch_device(cfg.device) - action_tensor = predict_action( - observation=obs_frame, - policy=ctx.policy, - device=device, - preprocessor=ctx.preprocessor, - postprocessor=ctx.postprocessor, - use_amp=ctx.policy.config.use_amp, - task=task_str, - robot_type=robot.robot_type, - ) - robot_action = make_robot_action(action_tensor, dataset.features) - action_t = torch.tensor([robot_action[k] for k in dataset_action_keys]) - interpolator.add(action_t) - - interp = interpolator.get() - if interp is not None: - robot_action = {k: interp[i].item() for i, k in enumerate(dataset_action_keys)} - processed = ctx.robot_action_processor((robot_action, obs)) - robot.send_action(processed) - last_action = processed - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - if record_tick % record_stride == 0: - frame = {**obs_frame, **action_frame, "task": task_str} - if stream_online: - dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 + if stream_online: + dataset.add_frame(frame) + else: + frame_buffer.append(frame) + record_tick += 1 dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index aa33d851d..3e772470d 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -16,21 +16,22 @@ from __future__ import annotations +import contextlib import logging import time +from threading import Event as ThreadingEvent from lerobot.datasets import VideoEncodingManager from lerobot.policies.rtc import ActionInterpolator -from lerobot.policies.utils import make_robot_action from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep from ..configs import HighlightStrategyConfig from ..context import RolloutContext -from ..inference import InferenceEngine, _resolve_action_key_order +from ..inference import InferenceEngine from ..ring_buffer import RolloutRingBuffer -from . import RolloutStrategy +from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) @@ -54,13 +55,14 @@ class HighlightStrategy(RolloutStrategy): def __init__(self, config: HighlightStrategyConfig): super().__init__(config) self._engine: InferenceEngine | None = None + self._interpolator: ActionInterpolator | None = None self._ring: RolloutRingBuffer | None = None self._listener = None - self._save_requested = False - self._recording_live = False + self._save_requested = ThreadingEvent() + self._recording_live = ThreadingEvent() def setup(self, ctx: RolloutContext) -> None: - interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) + self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) self._engine = InferenceEngine( policy=ctx.policy, @@ -73,7 +75,6 @@ class HighlightStrategy(RolloutStrategy): task=ctx.cfg.task, fps=ctx.cfg.fps, device=ctx.cfg.device, - interpolator=interpolator, use_torch_compile=ctx.cfg.use_torch_compile, compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, ) @@ -97,17 +98,12 @@ class HighlightStrategy(RolloutStrategy): cfg = ctx.cfg robot = ctx.robot_wrapper dataset = ctx.dataset - action_keys = ctx.action_keys ring = self._ring + interpolator = self._interpolator - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) control_interval = interpolator.get_control_interval(cfg.fps) - - policy_action_names = getattr(cfg.policy, "action_feature_names", None) - ordered_keys = _resolve_action_key_order( - list(policy_action_names) if policy_action_names else None, - action_keys, - ) + ordered_keys = ctx.ordered_action_keys + features = dataset.features if engine.is_rtc: engine.resume() @@ -126,70 +122,58 @@ class HighlightStrategy(RolloutStrategy): obs = robot.get_observation() obs_processed = ctx.robot_observation_processor(obs) - action_dict = None if engine.is_rtc: engine.update_observation(obs_processed) - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) - continue + if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True + if cfg.use_torch_compile and not warmup_flushed: + engine.reset() + interpolator.reset() + warmup_flushed = True + if engine.is_rtc: + engine.resume() - if interpolator.needs_new_action(): - action_tensor = engine.consume_rtc_action() - if action_tensor is not None: - interpolator.add(action_tensor.cpu()) - - interp = interpolator.get() - if interp is not None: - action_dict = { - k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp) - } - processed = ctx.robot_action_processor((action_dict, obs)) - robot.send_action(processed) - else: - obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR) - action_tensor = engine.get_action_sync(obs_frame) - action_dict = make_robot_action(action_tensor, ctx.dataset_features) - processed = ctx.robot_action_processor((action_dict, obs)) - robot.send_action(processed) + action_dict = infer_action( + engine, obs_processed, obs, ctx, interpolator, ordered_keys, features + ) # Build frame for ring buffer / live recording if action_dict is not None: - obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR) - action_frame = build_dataset_frame(ctx.dataset_features, action_dict, prefix=ACTION) + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": task_str} # Handle save key toggle - if self._save_requested: - self._save_requested = False - if not self._recording_live: + if self._save_requested.is_set(): + self._save_requested.clear() + if not self._recording_live.is_set(): logger.info( "Flushing ring buffer (%d frames) + starting live recording", len(ring) ) for buffered_frame in ring.drain(): dataset.add_frame(buffered_frame) - self._recording_live = True + self._recording_live.set() else: + # Save current frame as the last frame of the episode dataset.add_frame(frame) dataset.save_episode() logger.info("Episode saved") - self._recording_live = False + self._recording_live.clear() engine.reset() interpolator.reset() if engine.is_rtc: engine.resume() - if self._recording_live: + if self._recording_live.is_set(): dataset.add_frame(frame) else: + # Current frame goes into the ring buffer for next potential save. ring.append(frame) dt = time.perf_counter() - loop_start @@ -197,11 +181,9 @@ class HighlightStrategy(RolloutStrategy): precise_sleep(sleep_t) finally: - if self._recording_live: - try: + if self._recording_live.is_set(): + with contextlib.suppress(Exception): dataset.save_episode() - except Exception: - pass def teardown(self, ctx: RolloutContext) -> None: if self._engine is not None: @@ -237,13 +219,11 @@ class HighlightStrategy(RolloutStrategy): save_key = self.config.save_key def on_press(key): - try: + with contextlib.suppress(Exception): if hasattr(key, "char") and key.char == save_key: - self._save_requested = True + self._save_requested.set() elif key == keyboard.Key.esc: - self._save_requested = False - except Exception: - pass + self._save_requested.clear() self._listener = keyboard.Listener(on_press=on_press) self._listener.start() diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index de59c0d5c..3b36cd0dd 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -16,21 +16,21 @@ from __future__ import annotations +import contextlib import logging import time from threading import Thread from lerobot.datasets import VideoEncodingManager from lerobot.policies.rtc import ActionInterpolator -from lerobot.policies.utils import make_robot_action from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep from ..configs import SentryStrategyConfig from ..context import RolloutContext -from ..inference import InferenceEngine, _resolve_action_key_order -from . import RolloutStrategy +from ..inference import InferenceEngine +from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) @@ -55,10 +55,12 @@ class SentryStrategy(RolloutStrategy): def __init__(self, config: SentryStrategyConfig): super().__init__(config) self._engine: InferenceEngine | None = None + self._interpolator: ActionInterpolator | None = None self._push_thread: Thread | None = None + self._needs_push: bool = False def setup(self, ctx: RolloutContext) -> None: - interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) + self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) self._engine = InferenceEngine( policy=ctx.policy, @@ -71,7 +73,6 @@ class SentryStrategy(RolloutStrategy): task=ctx.cfg.task, fps=ctx.cfg.fps, device=ctx.cfg.device, - interpolator=interpolator, use_torch_compile=ctx.cfg.use_torch_compile, compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, ) @@ -87,16 +88,11 @@ class SentryStrategy(RolloutStrategy): cfg = ctx.cfg robot = ctx.robot_wrapper dataset = ctx.dataset - action_keys = ctx.action_keys + interpolator = self._interpolator - interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) control_interval = interpolator.get_control_interval(cfg.fps) - - policy_action_names = getattr(cfg.policy, "action_feature_names", None) - ordered_keys = _resolve_action_key_order( - list(policy_action_names) if policy_action_names else None, - action_keys, - ) + ordered_keys = ctx.ordered_action_keys + features = dataset.features if engine.is_rtc: engine.resume() @@ -117,45 +113,31 @@ class SentryStrategy(RolloutStrategy): obs = robot.get_observation() obs_processed = ctx.robot_observation_processor(obs) - action_dict = None if engine.is_rtc: engine.update_observation(obs_processed) - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) - continue + if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True + if cfg.use_torch_compile and not warmup_flushed: + engine.reset() + interpolator.reset() + warmup_flushed = True + if engine.is_rtc: + engine.resume() - if interpolator.needs_new_action(): - action_tensor = engine.consume_rtc_action() - if action_tensor is not None: - interpolator.add(action_tensor.cpu()) - - interp = interpolator.get() - if interp is not None: - action_dict = { - k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp) - } - processed = ctx.robot_action_processor((action_dict, obs)) - robot.send_action(processed) - else: - obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR) - action_tensor = engine.get_action_sync(obs_frame) - action_dict = make_robot_action(action_tensor, ctx.dataset_features) - processed = ctx.robot_action_processor((action_dict, obs)) - robot.send_action(processed) + action_dict = infer_action( + engine, obs_processed, obs, ctx, interpolator, ordered_keys, features + ) # Record frame if action_dict is not None: - obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR) - action_frame = build_dataset_frame(ctx.dataset_features, action_dict, prefix=ACTION) + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": task_str} dataset.add_frame(frame) @@ -164,6 +146,7 @@ class SentryStrategy(RolloutStrategy): if elapsed >= self.config.episode_duration_s: dataset.save_episode() episodes_since_push += 1 + self._needs_push = True logger.info("Episode saved (total: %d)", dataset.num_episodes) if episodes_since_push >= self.config.upload_every_n_episodes: @@ -181,26 +164,27 @@ class SentryStrategy(RolloutStrategy): precise_sleep(sleep_t) finally: - try: + with contextlib.suppress(Exception): dataset.save_episode() - except Exception: - pass + self._needs_push = True def teardown(self, ctx: RolloutContext) -> None: if self._engine is not None: self._engine.stop() + # Wait for any in-flight background push + if self._push_thread is not None and self._push_thread.is_alive(): + self._push_thread.join(timeout=60) + if ctx.dataset is not None: ctx.dataset.finalize() - if ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub: + # Only push if there are unsaved changes since last background push + if self._needs_push and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub: ctx.dataset.push_to_hub( tags=ctx.cfg.dataset.tags, private=ctx.cfg.dataset.private, ) - if self._push_thread is not None and self._push_thread.is_alive(): - self._push_thread.join(timeout=60) - if ctx.robot.is_connected: ctx.robot.disconnect() if ctx.teleop is not None and ctx.teleop.is_connected: @@ -219,6 +203,7 @@ class SentryStrategy(RolloutStrategy): tags=cfg.dataset.tags if cfg.dataset else None, private=cfg.dataset.private if cfg.dataset else False, ) + self._needs_push = False logger.info("Background push to hub complete") except Exception as e: logger.error("Background push failed: %s", e)