mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
address review
This commit is contained in:
@@ -68,12 +68,12 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
||||
|
||||
## Script
|
||||
|
||||
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Toggle RTC with `--rtc.enabled=true`:
|
||||
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
|
||||
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | -------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | ---------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
|
||||
|
||||
---
|
||||
|
||||
@@ -122,10 +122,10 @@ For models with high inference latency, enable RTC for smooth execution:
|
||||
|
||||
```bash
|
||||
lerobot-rollout --strategy.type=dagger \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=20 \
|
||||
--inference.rtc.max_guidance_weight=5.0 \
|
||||
--inference.rtc.prefix_attention_schedule=LINEAR \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
|
||||
@@ -542,4 +542,4 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
|
||||
All strategies support `--rtc.enabled=true` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
+4
-4
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
|
||||
|
||||
### Using RTC with Pi0
|
||||
|
||||
You can use `lerobot-rollout --strategy.type=base --rtc.enabled=true` for RTC deployment on real robots.
|
||||
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
|
||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||
|
||||
```python
|
||||
@@ -140,9 +140,9 @@ The script generates a visualization of the denoising process, comparing standar
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=10 \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
|
||||
@@ -284,7 +284,7 @@ python examples/rtc/eval_with_real_robot.py \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--rtc.enabled=true
|
||||
--inference.type=rtc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -242,18 +242,25 @@ def build_rollout_context(
|
||||
}
|
||||
action_features_hw = robot.action_features
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
# The action side is always needed: sync inference reads action names from
|
||||
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
|
||||
action_dataset_features = aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||
)
|
||||
# Observation-side aggregation only feeds the dataset schema; skip it for
|
||||
# the base strategy to avoid running observation pipelines that may have
|
||||
# dataset-specific dependencies.
|
||||
if cfg.dataset is not None:
|
||||
observation_dataset_features = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=observation_features_hw),
|
||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||
),
|
||||
)
|
||||
use_videos=cfg.dataset.video,
|
||||
)
|
||||
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
|
||||
else:
|
||||
dataset_features = action_dataset_features
|
||||
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
raw_action_keys = list(robot.action_features.keys())
|
||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||
|
||||
@@ -68,6 +68,9 @@ class SyncInferenceConfig(InferenceStrategyConfig):
|
||||
class RTCInferenceConfig(InferenceStrategyConfig):
|
||||
"""Real-Time Chunking: async policy inference in a background thread."""
|
||||
|
||||
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
|
||||
# constructing one here costs nothing and keeps draccus' CLI surface flat
|
||||
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
queue_threshold: int = 30
|
||||
|
||||
|
||||
@@ -51,6 +51,15 @@ from .base import InferenceStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
|
||||
_RTC_IDLE_SLEEP_S: float = 0.01
|
||||
# Backoff between transient inference errors (per consecutive failure).
|
||||
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
|
||||
# Consecutive transient errors tolerated before giving up and propagating shutdown.
|
||||
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
|
||||
# Hard timeout for joining the RTC thread on stop().
|
||||
_RTC_JOIN_TIMEOUT_S: float = 3.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTC helpers (extracted from examples/rtc and examples/hil)
|
||||
@@ -208,7 +217,7 @@ class RTCInferenceStrategy(InferenceStrategy):
|
||||
self._shutdown_event.set()
|
||||
self._policy_active.clear()
|
||||
if self._rtc_thread is not None and self._rtc_thread.is_alive():
|
||||
self._rtc_thread.join(timeout=3.0)
|
||||
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
|
||||
self._rtc_thread = None
|
||||
|
||||
def pause(self) -> None:
|
||||
@@ -252,17 +261,18 @@ class RTCInferenceStrategy(InferenceStrategy):
|
||||
|
||||
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
|
||||
inference_count = 0
|
||||
consecutive_errors = 0
|
||||
|
||||
while not self._shutdown_event.is_set():
|
||||
if not self._policy_active.is_set():
|
||||
time.sleep(0.01)
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
queue = self._action_queue
|
||||
with self._obs_lock:
|
||||
obs = self._obs_holder.get("obs")
|
||||
if queue is None or obs is None:
|
||||
time.sleep(0.01)
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
if queue.qsize() <= self._rtc_queue_threshold:
|
||||
@@ -310,6 +320,7 @@ class RTCInferenceStrategy(InferenceStrategy):
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
|
||||
inference_count += 1
|
||||
consecutive_errors = 0
|
||||
is_warmup = self._use_torch_compile and inference_count <= warmup_required
|
||||
if is_warmup:
|
||||
latency_tracker.reset()
|
||||
@@ -329,11 +340,20 @@ class RTCInferenceStrategy(InferenceStrategy):
|
||||
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
|
||||
|
||||
except Exception as e:
|
||||
logger.error("RTC inference error: %s", e)
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"RTC inference error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_RTC_MAX_CONSECUTIVE_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
time.sleep(0.5)
|
||||
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
|
||||
# Persistent failure: stop retrying and propagate shutdown.
|
||||
raise
|
||||
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Fatal error in RTC thread: %s", e)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
@@ -55,7 +56,7 @@ class SyncInferenceStrategy(InferenceStrategy):
|
||||
self._dataset_features = dataset_features
|
||||
self._ordered_action_keys = ordered_action_keys
|
||||
self._task = task
|
||||
self._device = device or "cpu"
|
||||
self._device = torch.device(device or "cpu")
|
||||
self._robot_type = robot_type
|
||||
|
||||
def start(self) -> None:
|
||||
@@ -72,16 +73,18 @@ class SyncInferenceStrategy(InferenceStrategy):
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
if obs_frame is None:
|
||||
return None
|
||||
# Shallow copy is intentional: the caller (`send_next_action`) builds
|
||||
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
|
||||
# tensor/array values are not shared with any other reader.
|
||||
observation = copy(obs_frame)
|
||||
policy_device = torch.device(self._device)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=policy_device.type)
|
||||
if policy_device.type == "cuda" and self._policy.config.use_amp
|
||||
else torch.inference_mode(),
|
||||
):
|
||||
autocast_ctx = (
|
||||
torch.autocast(device_type=self._device.type)
|
||||
if self._device.type == "cuda" and self._policy.config.use_amp
|
||||
else nullcontext()
|
||||
)
|
||||
with torch.inference_mode(), autocast_ctx:
|
||||
observation = prepare_observation_for_inference(
|
||||
observation, policy_device, self._task, self._robot_type
|
||||
observation, self._device, self._task, self._robot_type
|
||||
)
|
||||
observation = self._preprocessor(observation)
|
||||
action = self._policy.select_action(observation)
|
||||
|
||||
@@ -19,6 +19,7 @@ from __future__ import annotations
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class RolloutRingBuffer:
|
||||
@@ -28,6 +29,12 @@ class RolloutRingBuffer:
|
||||
time (``max_frames``) and memory (``max_memory_bytes``). When either
|
||||
limit is reached the oldest frames are evicted.
|
||||
|
||||
.. note::
|
||||
This class is **single-threaded**. ``append``/``drain``/``clear``
|
||||
must all be called from the same thread (the rollout main loop).
|
||||
Concurrent access from a background thread will corrupt
|
||||
``_current_bytes`` accounting.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_seconds:
|
||||
@@ -91,7 +98,11 @@ def _estimate_frame_bytes(frame: dict) -> int:
|
||||
"""Rough byte estimate for a single frame dictionary."""
|
||||
total = 0
|
||||
for v in frame.values():
|
||||
if isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
|
||||
if isinstance(v, torch.Tensor):
|
||||
# ``torch.Tensor`` has no ``nbytes``; compute it explicitly so the
|
||||
# memory cap is honoured even when frames hold unconverted tensors.
|
||||
total += v.nelement() * v.element_size()
|
||||
elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
|
||||
total += v.nbytes
|
||||
elif isinstance(v, (int, float)):
|
||||
total += 8
|
||||
|
||||
@@ -33,7 +33,7 @@ import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import time
|
||||
from threading import Lock
|
||||
from threading import Event, Lock
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -90,14 +90,16 @@ class DAggerEvents:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition: str | None = None
|
||||
|
||||
# Episode-level flags (written by keyboard, consumed by main loop)
|
||||
self.exit_early: bool = False
|
||||
self.rerecord_episode: bool = False
|
||||
self.stop_recording: bool = False
|
||||
# Episode-level flags written by keyboard/pedal threads, consumed by
|
||||
# the main loop. ``threading.Event`` gives us atomic set/clear/check
|
||||
# semantics without taking ``self._lock``.
|
||||
self.exit_early = Event()
|
||||
self.rerecord_episode = Event()
|
||||
self.stop_recording = Event()
|
||||
|
||||
# Reset-phase flags (simpler lifecycle, shared between threads)
|
||||
self.in_reset: bool = False
|
||||
self.start_next_episode: bool = False
|
||||
# Reset-phase flags (simpler lifecycle, shared between threads).
|
||||
self.in_reset = Event()
|
||||
self.start_next_episode = Event()
|
||||
|
||||
# -- Thread-safe phase access ------------------------------------------
|
||||
|
||||
@@ -140,8 +142,8 @@ class DAggerEvents:
|
||||
with self._lock:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition = None
|
||||
self.exit_early = False
|
||||
self.rerecord_episode = False
|
||||
self.exit_early.clear()
|
||||
self.rerecord_episode.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -198,24 +200,24 @@ def _reset_loop(
|
||||
"""Reset period where the human repositions the environment."""
|
||||
logger.info("RESET — press any key to enable teleoperation")
|
||||
|
||||
events.in_reset = True
|
||||
events.start_next_episode = False
|
||||
events.in_reset.set()
|
||||
events.start_next_episode.clear()
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
while not events.start_next_episode and not events.stop_recording:
|
||||
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events.stop_recording:
|
||||
if events.stop_recording.is_set():
|
||||
return
|
||||
|
||||
events.start_next_episode = False
|
||||
events.start_next_episode.clear()
|
||||
_teleop_disable_torque(teleop)
|
||||
logger.info("Teleop enabled — press any key to start episode")
|
||||
|
||||
while not events.start_next_episode and not events.stop_recording:
|
||||
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
|
||||
loop_start = time.perf_counter()
|
||||
obs = robot.get_observation()
|
||||
action = teleop.get_action()
|
||||
@@ -224,8 +226,8 @@ def _reset_loop(
|
||||
robot.send_action(robot_action_to_send)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events.in_reset = False
|
||||
events.start_next_episode = False
|
||||
events.in_reset.clear()
|
||||
events.start_next_episode.clear()
|
||||
events.reset_for_episode()
|
||||
|
||||
|
||||
@@ -242,16 +244,16 @@ def _init_dagger_keyboard(events: DAggerEvents):
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events.in_reset:
|
||||
if events.in_reset.is_set():
|
||||
if (
|
||||
key in [keyboard.Key.space, keyboard.Key.right]
|
||||
or hasattr(key, "char")
|
||||
and key.char == "c"
|
||||
):
|
||||
events.start_next_episode = True
|
||||
events.start_next_episode.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
events.stop_recording = True
|
||||
events.start_next_episode = True
|
||||
events.stop_recording.set()
|
||||
events.start_next_episode.set()
|
||||
return
|
||||
|
||||
phase = events.phase
|
||||
@@ -275,15 +277,15 @@ def _init_dagger_keyboard(events: DAggerEvents):
|
||||
|
||||
elif key == keyboard.Key.right:
|
||||
logger.info("End episode")
|
||||
events.exit_early = True
|
||||
events.exit_early.set()
|
||||
elif key == keyboard.Key.left:
|
||||
logger.info("Re-record episode")
|
||||
events.rerecord_episode = True
|
||||
events.exit_early = True
|
||||
events.rerecord_episode.set()
|
||||
events.exit_early.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("Stop recording...")
|
||||
events.stop_recording = True
|
||||
events.exit_early = True
|
||||
events.stop_recording.set()
|
||||
events.exit_early.set()
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
|
||||
@@ -301,8 +303,8 @@ def _dagger_pedal_callback(events: DAggerEvents):
|
||||
def on_press(code: str) -> None:
|
||||
if code not in _DAGGER_PEDAL_KEYS:
|
||||
return
|
||||
if events.in_reset:
|
||||
events.start_next_episode = True
|
||||
if events.in_reset.is_set():
|
||||
events.start_next_episode.set()
|
||||
return
|
||||
phase = events.phase
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
@@ -362,22 +364,22 @@ class DAggerStrategy(RolloutStrategy):
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded = 0
|
||||
while recorded < self.config.num_episodes and not events.stop_recording:
|
||||
while recorded < self.config.num_episodes and not events.stop_recording.is_set():
|
||||
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
|
||||
|
||||
self._run_episode(ctx)
|
||||
|
||||
if events.rerecord_episode:
|
||||
if events.rerecord_episode.is_set():
|
||||
log_say("Re-recording", self.config.play_sounds)
|
||||
events.rerecord_episode = False
|
||||
events.exit_early = False
|
||||
events.rerecord_episode.clear()
|
||||
events.exit_early.clear()
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < self.config.num_episodes and not events.stop_recording:
|
||||
if recorded < self.config.num_episodes and not events.stop_recording.is_set():
|
||||
_reset_loop(
|
||||
ctx.hardware.robot_wrapper,
|
||||
teleop,
|
||||
@@ -448,8 +450,8 @@ class DAggerStrategy(RolloutStrategy):
|
||||
while timestamp < self.config.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events.exit_early:
|
||||
events.exit_early = False
|
||||
if events.exit_early.is_set():
|
||||
events.exit_early.clear()
|
||||
break
|
||||
|
||||
transition = events.consume_transition()
|
||||
|
||||
@@ -118,6 +118,14 @@ class HighlightStrategy(RolloutStrategy):
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
|
||||
# NOTE: ``is_set()`` then ``clear()`` is not atomic
|
||||
# against the keyboard thread setting the flag again
|
||||
# in between — but that is benign: we lose at most one
|
||||
# toggle, processed on the next iteration. The
|
||||
# ``_recording_live`` branch below is reached in the
|
||||
# SAME iteration after ``clear()`` runs, so a frame
|
||||
# finalised by ``save_episode()`` is never re-added to
|
||||
# the next episode.
|
||||
if self._save_requested.is_set():
|
||||
self._save_requested.clear()
|
||||
if not self._recording_live.is_set():
|
||||
|
||||
@@ -108,10 +108,18 @@ class SentryStrategy(RolloutStrategy):
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
# ``add_frame`` writes to the in-progress episode buffer; the
|
||||
# background pusher only ever touches *finalised* episode
|
||||
# artifacts on disk. The two operate on disjoint state, so
|
||||
# ``add_frame`` does not need ``_episode_lock``.
|
||||
dataset.add_frame(frame)
|
||||
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= self.config.episode_duration_s:
|
||||
# ``save_episode`` finalises the in-progress episode and
|
||||
# flushes it to disk; ``_episode_lock`` serialises this with
|
||||
# ``push_to_hub`` (run in the background executor) so the
|
||||
# pusher never reads a half-written episode.
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
|
||||
Reference in New Issue
Block a user