address review

This commit is contained in:
Steven Palma
2026-04-15 19:31:53 +02:00
parent edd7fc52a8
commit 4e3175ff15
12 changed files with 139 additions and 77 deletions
+9 -9
View File
@@ -68,12 +68,12 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
## Script ## Script
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Toggle RTC with `--rtc.enabled=true`: Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
| Mode | Flag | Models | | Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- | | ------------------------ | ---------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy | | Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA | | Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
--- ---
@@ -122,10 +122,10 @@ For models with high inference latency, enable RTC for smooth execution:
```bash ```bash
lerobot-rollout --strategy.type=dagger \ lerobot-rollout --strategy.type=dagger \
--rtc.enabled=true \ --inference.type=rtc \
--rtc.execution_horizon=20 \ --inference.rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \ --inference.rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \ --inference.rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \ --robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \ --robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \ --robot.left_arm_config.side=left \
+1 -1
View File
@@ -542,4 +542,4 @@ The `--strategy.type` flag selects the execution mode:
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events) - `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection)) - `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
All strategies support `--rtc.enabled=true` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA). All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+4 -4
View File
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
### Using RTC with Pi0 ### Using RTC with Pi0
You can use `lerobot-rollout --strategy.type=base --rtc.enabled=true` for RTC deployment on real robots. You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline: The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python ```python
@@ -140,9 +140,9 @@ The script generates a visualization of the denoising process, comparing standar
lerobot-rollout \ lerobot-rollout \
--strategy.type=base \ --strategy.type=base \
--policy.path=${HF_USERNAME}/policy_repo_id \ --policy.path=${HF_USERNAME}/policy_repo_id \
--rtc.enabled=true \ --inference.type=rtc \
--rtc.execution_horizon=10 \ --inference.rtc.execution_horizon=10 \
--rtc.max_guidance_weight=10.0 \ --inference.rtc.max_guidance_weight=10.0 \
--robot.type=so100_follower \ --robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \ --robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
+1 -1
View File
@@ -284,7 +284,7 @@ python examples/rtc/eval_with_real_robot.py \
--task="task_description" \ --task="task_description" \
--duration=1000 \ --duration=1000 \
--fps=30 \ --fps=30 \
--rtc.enabled=true --inference.type=rtc
``` ```
--- ---
+17 -10
View File
@@ -242,18 +242,25 @@ def build_rollout_context(
} }
action_features_hw = robot.action_features action_features_hw = robot.action_features
dataset_features = combine_feature_dicts( # The action side is always needed: sync inference reads action names from
aggregate_pipeline_dataset_features( # ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
pipeline=teleop_action_processor, action_dataset_features = aggregate_pipeline_dataset_features(
initial_features=create_initial_features(action=action_features_hw), pipeline=teleop_action_processor,
use_videos=cfg.dataset.video if cfg.dataset else True, initial_features=create_initial_features(action=action_features_hw),
), use_videos=cfg.dataset.video if cfg.dataset else True,
aggregate_pipeline_dataset_features( )
# Observation-side aggregation only feeds the dataset schema; skip it for
# the base strategy to avoid running observation pipelines that may have
# dataset-specific dependencies.
if cfg.dataset is not None:
observation_dataset_features = aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor, pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=observation_features_hw), initial_features=create_initial_features(observation=observation_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True, use_videos=cfg.dataset.video,
), )
) dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
else:
dataset_features = action_dataset_features
hw_features = hw_to_dataset_features(observation_features_hw, "observation") hw_features = hw_to_dataset_features(observation_features_hw, "observation")
raw_action_keys = list(robot.action_features.keys()) raw_action_keys = list(robot.action_features.keys())
policy_action_names = getattr(policy_config, "action_feature_names", None) policy_action_names = getattr(policy_config, "action_feature_names", None)
+3
View File
@@ -68,6 +68,9 @@ class SyncInferenceConfig(InferenceStrategyConfig):
class RTCInferenceConfig(InferenceStrategyConfig): class RTCInferenceConfig(InferenceStrategyConfig):
"""Real-Time Chunking: async policy inference in a background thread.""" """Real-Time Chunking: async policy inference in a background thread."""
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
# constructing one here costs nothing and keeps draccus' CLI surface flat
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
rtc: RTCConfig = field(default_factory=RTCConfig) rtc: RTCConfig = field(default_factory=RTCConfig)
queue_threshold: int = 30 queue_threshold: int = 30
+26 -6
View File
@@ -51,6 +51,15 @@ from .base import InferenceStrategy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
_RTC_IDLE_SLEEP_S: float = 0.01
# Fixed delay slept after each transient inference error before retrying.
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
# Consecutive inference errors at which the RTC loop stops retrying and re-raises.
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
# Hard timeout for joining the RTC thread on stop().
_RTC_JOIN_TIMEOUT_S: float = 3.0
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# RTC helpers (extracted from examples/rtc and examples/hil) # RTC helpers (extracted from examples/rtc and examples/hil)
@@ -208,7 +217,7 @@ class RTCInferenceStrategy(InferenceStrategy):
self._shutdown_event.set() self._shutdown_event.set()
self._policy_active.clear() self._policy_active.clear()
if self._rtc_thread is not None and self._rtc_thread.is_alive(): if self._rtc_thread is not None and self._rtc_thread.is_alive():
self._rtc_thread.join(timeout=3.0) self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
self._rtc_thread = None self._rtc_thread = None
def pause(self) -> None: def pause(self) -> None:
@@ -252,17 +261,18 @@ class RTCInferenceStrategy(InferenceStrategy):
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0 warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
inference_count = 0 inference_count = 0
consecutive_errors = 0
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
if not self._policy_active.is_set(): if not self._policy_active.is_set():
time.sleep(0.01) time.sleep(_RTC_IDLE_SLEEP_S)
continue continue
queue = self._action_queue queue = self._action_queue
with self._obs_lock: with self._obs_lock:
obs = self._obs_holder.get("obs") obs = self._obs_holder.get("obs")
if queue is None or obs is None: if queue is None or obs is None:
time.sleep(0.01) time.sleep(_RTC_IDLE_SLEEP_S)
continue continue
if queue.qsize() <= self._rtc_queue_threshold: if queue.qsize() <= self._rtc_queue_threshold:
@@ -310,6 +320,7 @@ class RTCInferenceStrategy(InferenceStrategy):
new_delay = math.ceil(new_latency / time_per_chunk) new_delay = math.ceil(new_latency / time_per_chunk)
inference_count += 1 inference_count += 1
consecutive_errors = 0
is_warmup = self._use_torch_compile and inference_count <= warmup_required is_warmup = self._use_torch_compile and inference_count <= warmup_required
if is_warmup: if is_warmup:
latency_tracker.reset() latency_tracker.reset()
@@ -329,11 +340,20 @@ class RTCInferenceStrategy(InferenceStrategy):
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize()) logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
except Exception as e: except Exception as e:
logger.error("RTC inference error: %s", e) consecutive_errors += 1
logger.error(
"RTC inference error (%d/%d): %s",
consecutive_errors,
_RTC_MAX_CONSECUTIVE_ERRORS,
e,
)
logger.debug(traceback.format_exc()) logger.debug(traceback.format_exc())
time.sleep(0.5) if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
# Persistent failure: stop retrying and propagate shutdown.
raise
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
else: else:
time.sleep(0.01) time.sleep(_RTC_IDLE_SLEEP_S)
except Exception as e: except Exception as e:
logger.error("Fatal error in RTC thread: %s", e) logger.error("Fatal error in RTC thread: %s", e)
+12 -9
View File
@@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from contextlib import nullcontext
from copy import copy from copy import copy
import torch import torch
@@ -55,7 +56,7 @@ class SyncInferenceStrategy(InferenceStrategy):
self._dataset_features = dataset_features self._dataset_features = dataset_features
self._ordered_action_keys = ordered_action_keys self._ordered_action_keys = ordered_action_keys
self._task = task self._task = task
self._device = device or "cpu" self._device = torch.device(device or "cpu")
self._robot_type = robot_type self._robot_type = robot_type
def start(self) -> None: def start(self) -> None:
@@ -72,16 +73,18 @@ class SyncInferenceStrategy(InferenceStrategy):
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
if obs_frame is None: if obs_frame is None:
return None return None
# Shallow copy is intentional: the caller (`send_next_action`) builds
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame) observation = copy(obs_frame)
policy_device = torch.device(self._device) autocast_ctx = (
with ( torch.autocast(device_type=self._device.type)
torch.inference_mode(), if self._device.type == "cuda" and self._policy.config.use_amp
torch.autocast(device_type=policy_device.type) else nullcontext()
if policy_device.type == "cuda" and self._policy.config.use_amp )
else torch.inference_mode(), with torch.inference_mode(), autocast_ctx:
):
observation = prepare_observation_for_inference( observation = prepare_observation_for_inference(
observation, policy_device, self._task, self._robot_type observation, self._device, self._task, self._robot_type
) )
observation = self._preprocessor(observation) observation = self._preprocessor(observation)
action = self._policy.select_action(observation) action = self._policy.select_action(observation)
+12 -1
View File
@@ -19,6 +19,7 @@ from __future__ import annotations
from collections import deque from collections import deque
import numpy as np import numpy as np
import torch
class RolloutRingBuffer: class RolloutRingBuffer:
@@ -28,6 +29,12 @@ class RolloutRingBuffer:
time (``max_frames``) and memory (``max_memory_bytes``). When either time (``max_frames``) and memory (``max_memory_bytes``). When either
limit is reached the oldest frames are evicted. limit is reached the oldest frames are evicted.
.. note::
This class is **single-threaded**. ``append``/``drain``/``clear``
must all be called from the same thread (the rollout main loop).
Concurrent access from a background thread will corrupt
``_current_bytes`` accounting.
Parameters Parameters
---------- ----------
max_seconds: max_seconds:
@@ -91,7 +98,11 @@ def _estimate_frame_bytes(frame: dict) -> int:
"""Rough byte estimate for a single frame dictionary.""" """Rough byte estimate for a single frame dictionary."""
total = 0 total = 0
for v in frame.values(): for v in frame.values():
if isinstance(v, np.ndarray) or hasattr(v, "nbytes"): if isinstance(v, torch.Tensor):
# Not every ``torch`` release exposes ``Tensor.nbytes``; compute it explicitly so the
# memory cap is honoured even when frames hold unconverted tensors.
total += v.nelement() * v.element_size()
elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
total += v.nbytes total += v.nbytes
elif isinstance(v, (int, float)): elif isinstance(v, (int, float)):
total += 8 total += 8
+38 -36
View File
@@ -33,7 +33,7 @@ import contextlib
import enum import enum
import logging import logging
import time import time
from threading import Lock from threading import Event, Lock
from typing import Any from typing import Any
import numpy as np import numpy as np
@@ -90,14 +90,16 @@ class DAggerEvents:
self._phase = DAggerPhase.AUTONOMOUS self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition: str | None = None self._pending_transition: str | None = None
# Episode-level flags (written by keyboard, consumed by main loop) # Episode-level flags written by keyboard/pedal threads, consumed by
self.exit_early: bool = False # the main loop. ``threading.Event`` gives us atomic set/clear/check
self.rerecord_episode: bool = False # semantics without taking ``self._lock``.
self.stop_recording: bool = False self.exit_early = Event()
self.rerecord_episode = Event()
self.stop_recording = Event()
# Reset-phase flags (simpler lifecycle, shared between threads) # Reset-phase flags (simpler lifecycle, shared between threads).
self.in_reset: bool = False self.in_reset = Event()
self.start_next_episode: bool = False self.start_next_episode = Event()
# -- Thread-safe phase access ------------------------------------------ # -- Thread-safe phase access ------------------------------------------
@@ -140,8 +142,8 @@ class DAggerEvents:
with self._lock: with self._lock:
self._phase = DAggerPhase.AUTONOMOUS self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None self._pending_transition = None
self.exit_early = False self.exit_early.clear()
self.rerecord_episode = False self.rerecord_episode.clear()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -198,24 +200,24 @@ def _reset_loop(
"""Reset period where the human repositions the environment.""" """Reset period where the human repositions the environment."""
logger.info("RESET — press any key to enable teleoperation") logger.info("RESET — press any key to enable teleoperation")
events.in_reset = True events.in_reset.set()
events.start_next_episode = False events.start_next_episode.clear()
obs = robot.get_observation() obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
while not events.start_next_episode and not events.stop_recording: while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
precise_sleep(0.05) precise_sleep(0.05)
if events.stop_recording: if events.stop_recording.is_set():
return return
events.start_next_episode = False events.start_next_episode.clear()
_teleop_disable_torque(teleop) _teleop_disable_torque(teleop)
logger.info("Teleop enabled — press any key to start episode") logger.info("Teleop enabled — press any key to start episode")
while not events.start_next_episode and not events.stop_recording: while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
loop_start = time.perf_counter() loop_start = time.perf_counter()
obs = robot.get_observation() obs = robot.get_observation()
action = teleop.get_action() action = teleop.get_action()
@@ -224,8 +226,8 @@ def _reset_loop(
robot.send_action(robot_action_to_send) robot.send_action(robot_action_to_send)
precise_sleep(1 / fps - (time.perf_counter() - loop_start)) precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events.in_reset = False events.in_reset.clear()
events.start_next_episode = False events.start_next_episode.clear()
events.reset_for_episode() events.reset_for_episode()
@@ -242,16 +244,16 @@ def _init_dagger_keyboard(events: DAggerEvents):
def on_press(key): def on_press(key):
try: try:
if events.in_reset: if events.in_reset.is_set():
if ( if (
key in [keyboard.Key.space, keyboard.Key.right] key in [keyboard.Key.space, keyboard.Key.right]
or hasattr(key, "char") or hasattr(key, "char")
and key.char == "c" and key.char == "c"
): ):
events.start_next_episode = True events.start_next_episode.set()
elif key == keyboard.Key.esc: elif key == keyboard.Key.esc:
events.stop_recording = True events.stop_recording.set()
events.start_next_episode = True events.start_next_episode.set()
return return
phase = events.phase phase = events.phase
@@ -275,15 +277,15 @@ def _init_dagger_keyboard(events: DAggerEvents):
elif key == keyboard.Key.right: elif key == keyboard.Key.right:
logger.info("End episode") logger.info("End episode")
events.exit_early = True events.exit_early.set()
elif key == keyboard.Key.left: elif key == keyboard.Key.left:
logger.info("Re-record episode") logger.info("Re-record episode")
events.rerecord_episode = True events.rerecord_episode.set()
events.exit_early = True events.exit_early.set()
elif key == keyboard.Key.esc: elif key == keyboard.Key.esc:
logger.info("Stop recording...") logger.info("Stop recording...")
events.stop_recording = True events.stop_recording.set()
events.exit_early = True events.exit_early.set()
except Exception as e: except Exception as e:
logger.debug("Key error: %s", e) logger.debug("Key error: %s", e)
@@ -301,8 +303,8 @@ def _dagger_pedal_callback(events: DAggerEvents):
def on_press(code: str) -> None: def on_press(code: str) -> None:
if code not in _DAGGER_PEDAL_KEYS: if code not in _DAGGER_PEDAL_KEYS:
return return
if events.in_reset: if events.in_reset.is_set():
events.start_next_episode = True events.start_next_episode.set()
return return
phase = events.phase phase = events.phase
if phase == DAggerPhase.CORRECTING: if phase == DAggerPhase.CORRECTING:
@@ -362,22 +364,22 @@ class DAggerStrategy(RolloutStrategy):
with VideoEncodingManager(dataset): with VideoEncodingManager(dataset):
try: try:
recorded = 0 recorded = 0
while recorded < self.config.num_episodes and not events.stop_recording: while recorded < self.config.num_episodes and not events.stop_recording.is_set():
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds) log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
self._run_episode(ctx) self._run_episode(ctx)
if events.rerecord_episode: if events.rerecord_episode.is_set():
log_say("Re-recording", self.config.play_sounds) log_say("Re-recording", self.config.play_sounds)
events.rerecord_episode = False events.rerecord_episode.clear()
events.exit_early = False events.exit_early.clear()
dataset.clear_episode_buffer() dataset.clear_episode_buffer()
continue continue
dataset.save_episode() dataset.save_episode()
recorded += 1 recorded += 1
if recorded < self.config.num_episodes and not events.stop_recording: if recorded < self.config.num_episodes and not events.stop_recording.is_set():
_reset_loop( _reset_loop(
ctx.hardware.robot_wrapper, ctx.hardware.robot_wrapper,
teleop, teleop,
@@ -448,8 +450,8 @@ class DAggerStrategy(RolloutStrategy):
while timestamp < self.config.episode_time_s: while timestamp < self.config.episode_time_s:
loop_start = time.perf_counter() loop_start = time.perf_counter()
if events.exit_early: if events.exit_early.is_set():
events.exit_early = False events.exit_early.clear()
break break
transition = events.consume_transition() transition = events.consume_transition()
@@ -118,6 +118,14 @@ class HighlightStrategy(RolloutStrategy):
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str} frame = {**obs_frame, **action_frame, "task": task_str}
# NOTE: ``is_set()`` then ``clear()`` is not atomic
# against the keyboard thread setting the flag again
# in between — but that is benign: we lose at most one
# toggle, processed on the next iteration. The
# ``_recording_live`` branch below is reached in the
# SAME iteration after ``clear()`` runs, so a frame
# finalised by ``save_episode()`` is never re-added to
# the next episode.
if self._save_requested.is_set(): if self._save_requested.is_set():
self._save_requested.clear() self._save_requested.clear()
if not self._recording_live.is_set(): if not self._recording_live.is_set():
+8
View File
@@ -108,10 +108,18 @@ class SentryStrategy(RolloutStrategy):
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str} frame = {**obs_frame, **action_frame, "task": task_str}
# ``add_frame`` writes to the in-progress episode buffer; the
# background pusher only ever touches *finalised* episode
# artifacts on disk. The two operate on disjoint state, so
# ``add_frame`` does not need ``_episode_lock``.
dataset.add_frame(frame) dataset.add_frame(frame)
elapsed = time.perf_counter() - episode_start elapsed = time.perf_counter() - episode_start
if elapsed >= self.config.episode_duration_s: if elapsed >= self.config.episode_duration_s:
# ``save_episode`` finalises the in-progress episode and
# flushes it to disk; ``_episode_lock`` serialises this with
# ``push_to_hub`` (run in the background executor) so the
# pusher never reads a half-written episode.
with self._episode_lock: with self._episode_lock:
dataset.save_episode() dataset.save_episode()
episodes_since_push += 1 episodes_since_push += 1