address review

This commit is contained in:
Steven Palma
2026-04-15 19:31:53 +02:00
parent edd7fc52a8
commit 4e3175ff15
12 changed files with 139 additions and 77 deletions
+9 -9
View File
@@ -68,12 +68,12 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
## Script
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Toggle RTC with `--rtc.enabled=true`:
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
| Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
| Mode | Flag | Models |
| ------------------------ | ---------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
---
@@ -122,10 +122,10 @@ For models with high inference latency, enable RTC for smooth execution:
```bash
lerobot-rollout --strategy.type=dagger \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--inference.type=rtc \
--inference.rtc.execution_horizon=20 \
--inference.rtc.max_guidance_weight=5.0 \
--inference.rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
+1 -1
View File
@@ -542,4 +542,4 @@ The `--strategy.type` flag selects the execution mode:
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
All strategies support `--rtc.enabled=true` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+4 -4
View File
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
### Using RTC with Pi0
You can use `lerobot-rollout --strategy.type=base --rtc.enabled=true` for RTC deployment on real robots.
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
@@ -140,9 +140,9 @@ The script generates a visualization of the denoising process, comparing standar
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USERNAME}/policy_repo_id \
--rtc.enabled=true \
--rtc.execution_horizon=10 \
--rtc.max_guidance_weight=10.0 \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
+1 -1
View File
@@ -284,7 +284,7 @@ python examples/rtc/eval_with_real_robot.py \
--task="task_description" \
--duration=1000 \
--fps=30 \
--rtc.enabled=true
--inference.type=rtc
```
---
+17 -10
View File
@@ -242,18 +242,25 @@ def build_rollout_context(
}
action_features_hw = robot.action_features
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
),
aggregate_pipeline_dataset_features(
# The action side is always needed: sync inference reads action names from
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
action_dataset_features = aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
)
# Observation-side aggregation only feeds the dataset schema; skip it for
# the base strategy to avoid running observation pipelines that may have
# dataset-specific dependencies.
if cfg.dataset is not None:
observation_dataset_features = aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=observation_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
),
)
use_videos=cfg.dataset.video,
)
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
else:
dataset_features = action_dataset_features
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
raw_action_keys = list(robot.action_features.keys())
policy_action_names = getattr(policy_config, "action_feature_names", None)
+3
View File
@@ -68,6 +68,9 @@ class SyncInferenceConfig(InferenceStrategyConfig):
class RTCInferenceConfig(InferenceStrategyConfig):
"""Real-Time Chunking: async policy inference in a background thread."""
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
# constructing one here costs nothing and keeps draccus' CLI surface flat
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
rtc: RTCConfig = field(default_factory=RTCConfig)
queue_threshold: int = 30
+26 -6
View File
@@ -51,6 +51,15 @@ from .base import InferenceStrategy
logger = logging.getLogger(__name__)
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
_RTC_IDLE_SLEEP_S: float = 0.01
# Backoff between transient inference errors (per consecutive failure).
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
# Consecutive transient errors tolerated before giving up and propagating shutdown.
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
# Hard timeout for joining the RTC thread on stop().
_RTC_JOIN_TIMEOUT_S: float = 3.0
# ---------------------------------------------------------------------------
# RTC helpers (extracted from examples/rtc and examples/hil)
@@ -208,7 +217,7 @@ class RTCInferenceStrategy(InferenceStrategy):
self._shutdown_event.set()
self._policy_active.clear()
if self._rtc_thread is not None and self._rtc_thread.is_alive():
self._rtc_thread.join(timeout=3.0)
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
self._rtc_thread = None
def pause(self) -> None:
@@ -252,17 +261,18 @@ class RTCInferenceStrategy(InferenceStrategy):
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
inference_count = 0
consecutive_errors = 0
while not self._shutdown_event.is_set():
if not self._policy_active.is_set():
time.sleep(0.01)
time.sleep(_RTC_IDLE_SLEEP_S)
continue
queue = self._action_queue
with self._obs_lock:
obs = self._obs_holder.get("obs")
if queue is None or obs is None:
time.sleep(0.01)
time.sleep(_RTC_IDLE_SLEEP_S)
continue
if queue.qsize() <= self._rtc_queue_threshold:
@@ -310,6 +320,7 @@ class RTCInferenceStrategy(InferenceStrategy):
new_delay = math.ceil(new_latency / time_per_chunk)
inference_count += 1
consecutive_errors = 0
is_warmup = self._use_torch_compile and inference_count <= warmup_required
if is_warmup:
latency_tracker.reset()
@@ -329,11 +340,20 @@ class RTCInferenceStrategy(InferenceStrategy):
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
except Exception as e:
logger.error("RTC inference error: %s", e)
consecutive_errors += 1
logger.error(
"RTC inference error (%d/%d): %s",
consecutive_errors,
_RTC_MAX_CONSECUTIVE_ERRORS,
e,
)
logger.debug(traceback.format_exc())
time.sleep(0.5)
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
# Persistent failure: stop retrying and propagate shutdown.
raise
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
else:
time.sleep(0.01)
time.sleep(_RTC_IDLE_SLEEP_S)
except Exception as e:
logger.error("Fatal error in RTC thread: %s", e)
+12 -9
View File
@@ -17,6 +17,7 @@
from __future__ import annotations
import logging
from contextlib import nullcontext
from copy import copy
import torch
@@ -55,7 +56,7 @@ class SyncInferenceStrategy(InferenceStrategy):
self._dataset_features = dataset_features
self._ordered_action_keys = ordered_action_keys
self._task = task
self._device = device or "cpu"
self._device = torch.device(device or "cpu")
self._robot_type = robot_type
def start(self) -> None:
@@ -72,16 +73,18 @@ class SyncInferenceStrategy(InferenceStrategy):
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
if obs_frame is None:
return None
# Shallow copy is intentional: the caller (`send_next_action`) builds
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame)
policy_device = torch.device(self._device)
with (
torch.inference_mode(),
torch.autocast(device_type=policy_device.type)
if policy_device.type == "cuda" and self._policy.config.use_amp
else torch.inference_mode(),
):
autocast_ctx = (
torch.autocast(device_type=self._device.type)
if self._device.type == "cuda" and self._policy.config.use_amp
else nullcontext()
)
with torch.inference_mode(), autocast_ctx:
observation = prepare_observation_for_inference(
observation, policy_device, self._task, self._robot_type
observation, self._device, self._task, self._robot_type
)
observation = self._preprocessor(observation)
action = self._policy.select_action(observation)
+12 -1
View File
@@ -19,6 +19,7 @@ from __future__ import annotations
from collections import deque
import numpy as np
import torch
class RolloutRingBuffer:
@@ -28,6 +29,12 @@ class RolloutRingBuffer:
time (``max_frames``) and memory (``max_memory_bytes``). When either
limit is reached the oldest frames are evicted.
.. note::
This class is **single-threaded**. ``append``/``drain``/``clear``
must all be called from the same thread (the rollout main loop).
Concurrent access from a background thread will corrupt
``_current_bytes`` accounting.
Parameters
----------
max_seconds:
@@ -91,7 +98,11 @@ def _estimate_frame_bytes(frame: dict) -> int:
"""Rough byte estimate for a single frame dictionary."""
total = 0
for v in frame.values():
if isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
if isinstance(v, torch.Tensor):
# ``torch.Tensor`` has no ``nbytes``; compute it explicitly so the
# memory cap is honoured even when frames hold unconverted tensors.
total += v.nelement() * v.element_size()
elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
total += v.nbytes
elif isinstance(v, (int, float)):
total += 8
+38 -36
View File
@@ -33,7 +33,7 @@ import contextlib
import enum
import logging
import time
from threading import Lock
from threading import Event, Lock
from typing import Any
import numpy as np
@@ -90,14 +90,16 @@ class DAggerEvents:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition: str | None = None
# Episode-level flags (written by keyboard, consumed by main loop)
self.exit_early: bool = False
self.rerecord_episode: bool = False
self.stop_recording: bool = False
# Episode-level flags written by keyboard/pedal threads, consumed by
# the main loop. ``threading.Event`` gives us atomic set/clear/check
# semantics without taking ``self._lock``.
self.exit_early = Event()
self.rerecord_episode = Event()
self.stop_recording = Event()
# Reset-phase flags (simpler lifecycle, shared between threads)
self.in_reset: bool = False
self.start_next_episode: bool = False
# Reset-phase flags (simpler lifecycle, shared between threads).
self.in_reset = Event()
self.start_next_episode = Event()
# -- Thread-safe phase access ------------------------------------------
@@ -140,8 +142,8 @@ class DAggerEvents:
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self.exit_early = False
self.rerecord_episode = False
self.exit_early.clear()
self.rerecord_episode.clear()
# ---------------------------------------------------------------------------
@@ -198,24 +200,24 @@ def _reset_loop(
"""Reset period where the human repositions the environment."""
logger.info("RESET — press any key to enable teleoperation")
events.in_reset = True
events.start_next_episode = False
events.in_reset.set()
events.start_next_episode.clear()
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
while not events.start_next_episode and not events.stop_recording:
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
precise_sleep(0.05)
if events.stop_recording:
if events.stop_recording.is_set():
return
events.start_next_episode = False
events.start_next_episode.clear()
_teleop_disable_torque(teleop)
logger.info("Teleop enabled — press any key to start episode")
while not events.start_next_episode and not events.stop_recording:
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
loop_start = time.perf_counter()
obs = robot.get_observation()
action = teleop.get_action()
@@ -224,8 +226,8 @@ def _reset_loop(
robot.send_action(robot_action_to_send)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events.in_reset = False
events.start_next_episode = False
events.in_reset.clear()
events.start_next_episode.clear()
events.reset_for_episode()
@@ -242,16 +244,16 @@ def _init_dagger_keyboard(events: DAggerEvents):
def on_press(key):
try:
if events.in_reset:
if events.in_reset.is_set():
if (
key in [keyboard.Key.space, keyboard.Key.right]
or hasattr(key, "char")
and key.char == "c"
):
events.start_next_episode = True
events.start_next_episode.set()
elif key == keyboard.Key.esc:
events.stop_recording = True
events.start_next_episode = True
events.stop_recording.set()
events.start_next_episode.set()
return
phase = events.phase
@@ -275,15 +277,15 @@ def _init_dagger_keyboard(events: DAggerEvents):
elif key == keyboard.Key.right:
logger.info("End episode")
events.exit_early = True
events.exit_early.set()
elif key == keyboard.Key.left:
logger.info("Re-record episode")
events.rerecord_episode = True
events.exit_early = True
events.rerecord_episode.set()
events.exit_early.set()
elif key == keyboard.Key.esc:
logger.info("Stop recording...")
events.stop_recording = True
events.exit_early = True
events.stop_recording.set()
events.exit_early.set()
except Exception as e:
logger.debug("Key error: %s", e)
@@ -301,8 +303,8 @@ def _dagger_pedal_callback(events: DAggerEvents):
def on_press(code: str) -> None:
if code not in _DAGGER_PEDAL_KEYS:
return
if events.in_reset:
events.start_next_episode = True
if events.in_reset.is_set():
events.start_next_episode.set()
return
phase = events.phase
if phase == DAggerPhase.CORRECTING:
@@ -362,22 +364,22 @@ class DAggerStrategy(RolloutStrategy):
with VideoEncodingManager(dataset):
try:
recorded = 0
while recorded < self.config.num_episodes and not events.stop_recording:
while recorded < self.config.num_episodes and not events.stop_recording.is_set():
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
self._run_episode(ctx)
if events.rerecord_episode:
if events.rerecord_episode.is_set():
log_say("Re-recording", self.config.play_sounds)
events.rerecord_episode = False
events.exit_early = False
events.rerecord_episode.clear()
events.exit_early.clear()
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < self.config.num_episodes and not events.stop_recording:
if recorded < self.config.num_episodes and not events.stop_recording.is_set():
_reset_loop(
ctx.hardware.robot_wrapper,
teleop,
@@ -448,8 +450,8 @@ class DAggerStrategy(RolloutStrategy):
while timestamp < self.config.episode_time_s:
loop_start = time.perf_counter()
if events.exit_early:
events.exit_early = False
if events.exit_early.is_set():
events.exit_early.clear()
break
transition = events.consume_transition()
@@ -118,6 +118,14 @@ class HighlightStrategy(RolloutStrategy):
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
# NOTE: ``is_set()`` then ``clear()`` is not atomic
# against the keyboard thread setting the flag again
# in between — but that is benign: we lose at most one
# toggle, processed on the next iteration. The
# ``_recording_live`` branch below is reached in the
# SAME iteration after ``clear()`` runs, so a frame
# finalised by ``save_episode()`` is never re-added to
# the next episode.
if self._save_requested.is_set():
self._save_requested.clear()
if not self._recording_live.is_set():
+8
View File
@@ -108,10 +108,18 @@ class SentryStrategy(RolloutStrategy):
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
# ``add_frame`` writes to the in-progress episode buffer; the
# background pusher only ever touches *finalised* episode
# artifacts on disk. The two operate on disjoint state, so
# ``add_frame`` does not need ``_episode_lock``.
dataset.add_frame(frame)
elapsed = time.perf_counter() - episode_start
if elapsed >= self.config.episode_duration_s:
# ``save_episode`` finalises the in-progress episode and
# flushes it to disk; ``_episode_lock`` serialises this with
# ``push_to_hub`` (run in the background executor) so the
# pusher never reads a half-written episode.
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1