mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
address review
This commit is contained in:
@@ -68,12 +68,12 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
|||||||
|
|
||||||
## Script
|
## Script
|
||||||
|
|
||||||
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Toggle RTC with `--rtc.enabled=true`:
|
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
|
||||||
|
|
||||||
| Mode | Flag | Models |
|
| Mode | Flag | Models |
|
||||||
| ------------------------ | -------------------- | --------------------- |
|
| ------------------------ | ---------------------- | --------------------- |
|
||||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||||
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
|
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -122,10 +122,10 @@ For models with high inference latency, enable RTC for smooth execution:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-rollout --strategy.type=dagger \
|
lerobot-rollout --strategy.type=dagger \
|
||||||
--rtc.enabled=true \
|
--inference.type=rtc \
|
||||||
--rtc.execution_horizon=20 \
|
--inference.rtc.execution_horizon=20 \
|
||||||
--rtc.max_guidance_weight=5.0 \
|
--inference.rtc.max_guidance_weight=5.0 \
|
||||||
--rtc.prefix_attention_schedule=LINEAR \
|
--inference.rtc.prefix_attention_schedule=LINEAR \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--robot.left_arm_config.port=can1 \
|
--robot.left_arm_config.port=can1 \
|
||||||
--robot.left_arm_config.side=left \
|
--robot.left_arm_config.side=left \
|
||||||
|
|||||||
@@ -542,4 +542,4 @@ The `--strategy.type` flag selects the execution mode:
|
|||||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||||
|
|
||||||
All strategies support `--rtc.enabled=true` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||||
|
|||||||
+4
-4
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
|
|||||||
|
|
||||||
### Using RTC with Pi0
|
### Using RTC with Pi0
|
||||||
|
|
||||||
You can use `lerobot-rollout --strategy.type=base --rtc.enabled=true` for RTC deployment on real robots.
|
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
|
||||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -140,9 +140,9 @@ The script generates a visualization of the denoising process, comparing standar
|
|||||||
lerobot-rollout \
|
lerobot-rollout \
|
||||||
--strategy.type=base \
|
--strategy.type=base \
|
||||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||||
--rtc.enabled=true \
|
--inference.type=rtc \
|
||||||
--rtc.execution_horizon=10 \
|
--inference.rtc.execution_horizon=10 \
|
||||||
--rtc.max_guidance_weight=10.0 \
|
--inference.rtc.max_guidance_weight=10.0 \
|
||||||
--robot.type=so100_follower \
|
--robot.type=so100_follower \
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
|||||||
@@ -284,7 +284,7 @@ python examples/rtc/eval_with_real_robot.py \
|
|||||||
--task="task_description" \
|
--task="task_description" \
|
||||||
--duration=1000 \
|
--duration=1000 \
|
||||||
--fps=30 \
|
--fps=30 \
|
||||||
--rtc.enabled=true
|
--inference.type=rtc
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -242,18 +242,25 @@ def build_rollout_context(
|
|||||||
}
|
}
|
||||||
action_features_hw = robot.action_features
|
action_features_hw = robot.action_features
|
||||||
|
|
||||||
dataset_features = combine_feature_dicts(
|
# The action side is always needed: sync inference reads action names from
|
||||||
aggregate_pipeline_dataset_features(
|
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
|
||||||
pipeline=teleop_action_processor,
|
action_dataset_features = aggregate_pipeline_dataset_features(
|
||||||
initial_features=create_initial_features(action=action_features_hw),
|
pipeline=teleop_action_processor,
|
||||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
initial_features=create_initial_features(action=action_features_hw),
|
||||||
),
|
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||||
aggregate_pipeline_dataset_features(
|
)
|
||||||
|
# Observation-side aggregation only feeds the dataset schema; skip it for
|
||||||
|
# the base strategy to avoid running observation pipelines that may have
|
||||||
|
# dataset-specific dependencies.
|
||||||
|
if cfg.dataset is not None:
|
||||||
|
observation_dataset_features = aggregate_pipeline_dataset_features(
|
||||||
pipeline=robot_observation_processor,
|
pipeline=robot_observation_processor,
|
||||||
initial_features=create_initial_features(observation=observation_features_hw),
|
initial_features=create_initial_features(observation=observation_features_hw),
|
||||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
use_videos=cfg.dataset.video,
|
||||||
),
|
)
|
||||||
)
|
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
|
||||||
|
else:
|
||||||
|
dataset_features = action_dataset_features
|
||||||
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||||
raw_action_keys = list(robot.action_features.keys())
|
raw_action_keys = list(robot.action_features.keys())
|
||||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||||
|
|||||||
@@ -68,6 +68,9 @@ class SyncInferenceConfig(InferenceStrategyConfig):
|
|||||||
class RTCInferenceConfig(InferenceStrategyConfig):
|
class RTCInferenceConfig(InferenceStrategyConfig):
|
||||||
"""Real-Time Chunking: async policy inference in a background thread."""
|
"""Real-Time Chunking: async policy inference in a background thread."""
|
||||||
|
|
||||||
|
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
|
||||||
|
# constructing one here costs nothing and keeps draccus' CLI surface flat
|
||||||
|
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
|
||||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||||
queue_threshold: int = 30
|
queue_threshold: int = 30
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,15 @@ from .base import InferenceStrategy
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
|
||||||
|
_RTC_IDLE_SLEEP_S: float = 0.01
|
||||||
|
# Backoff between transient inference errors (per consecutive failure).
|
||||||
|
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
|
||||||
|
# Consecutive transient errors tolerated before giving up and propagating shutdown.
|
||||||
|
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
|
||||||
|
# Hard timeout for joining the RTC thread on stop().
|
||||||
|
_RTC_JOIN_TIMEOUT_S: float = 3.0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# RTC helpers (extracted from examples/rtc and examples/hil)
|
# RTC helpers (extracted from examples/rtc and examples/hil)
|
||||||
@@ -208,7 +217,7 @@ class RTCInferenceStrategy(InferenceStrategy):
|
|||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
self._policy_active.clear()
|
self._policy_active.clear()
|
||||||
if self._rtc_thread is not None and self._rtc_thread.is_alive():
|
if self._rtc_thread is not None and self._rtc_thread.is_alive():
|
||||||
self._rtc_thread.join(timeout=3.0)
|
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
|
||||||
self._rtc_thread = None
|
self._rtc_thread = None
|
||||||
|
|
||||||
def pause(self) -> None:
|
def pause(self) -> None:
|
||||||
@@ -252,17 +261,18 @@ class RTCInferenceStrategy(InferenceStrategy):
|
|||||||
|
|
||||||
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
|
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
|
||||||
inference_count = 0
|
inference_count = 0
|
||||||
|
consecutive_errors = 0
|
||||||
|
|
||||||
while not self._shutdown_event.is_set():
|
while not self._shutdown_event.is_set():
|
||||||
if not self._policy_active.is_set():
|
if not self._policy_active.is_set():
|
||||||
time.sleep(0.01)
|
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
queue = self._action_queue
|
queue = self._action_queue
|
||||||
with self._obs_lock:
|
with self._obs_lock:
|
||||||
obs = self._obs_holder.get("obs")
|
obs = self._obs_holder.get("obs")
|
||||||
if queue is None or obs is None:
|
if queue is None or obs is None:
|
||||||
time.sleep(0.01)
|
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if queue.qsize() <= self._rtc_queue_threshold:
|
if queue.qsize() <= self._rtc_queue_threshold:
|
||||||
@@ -310,6 +320,7 @@ class RTCInferenceStrategy(InferenceStrategy):
|
|||||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||||
|
|
||||||
inference_count += 1
|
inference_count += 1
|
||||||
|
consecutive_errors = 0
|
||||||
is_warmup = self._use_torch_compile and inference_count <= warmup_required
|
is_warmup = self._use_torch_compile and inference_count <= warmup_required
|
||||||
if is_warmup:
|
if is_warmup:
|
||||||
latency_tracker.reset()
|
latency_tracker.reset()
|
||||||
@@ -329,11 +340,20 @@ class RTCInferenceStrategy(InferenceStrategy):
|
|||||||
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
|
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("RTC inference error: %s", e)
|
consecutive_errors += 1
|
||||||
|
logger.error(
|
||||||
|
"RTC inference error (%d/%d): %s",
|
||||||
|
consecutive_errors,
|
||||||
|
_RTC_MAX_CONSECUTIVE_ERRORS,
|
||||||
|
e,
|
||||||
|
)
|
||||||
logger.debug(traceback.format_exc())
|
logger.debug(traceback.format_exc())
|
||||||
time.sleep(0.5)
|
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
|
||||||
|
# Persistent failure: stop retrying and propagate shutdown.
|
||||||
|
raise
|
||||||
|
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
|
||||||
else:
|
else:
|
||||||
time.sleep(0.01)
|
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Fatal error in RTC thread: %s", e)
|
logger.error("Fatal error in RTC thread: %s", e)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -55,7 +56,7 @@ class SyncInferenceStrategy(InferenceStrategy):
|
|||||||
self._dataset_features = dataset_features
|
self._dataset_features = dataset_features
|
||||||
self._ordered_action_keys = ordered_action_keys
|
self._ordered_action_keys = ordered_action_keys
|
||||||
self._task = task
|
self._task = task
|
||||||
self._device = device or "cpu"
|
self._device = torch.device(device or "cpu")
|
||||||
self._robot_type = robot_type
|
self._robot_type = robot_type
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
@@ -72,16 +73,18 @@ class SyncInferenceStrategy(InferenceStrategy):
|
|||||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||||
if obs_frame is None:
|
if obs_frame is None:
|
||||||
return None
|
return None
|
||||||
|
# Shallow copy is intentional: the caller (`send_next_action`) builds
|
||||||
|
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
|
||||||
|
# tensor/array values are not shared with any other reader.
|
||||||
observation = copy(obs_frame)
|
observation = copy(obs_frame)
|
||||||
policy_device = torch.device(self._device)
|
autocast_ctx = (
|
||||||
with (
|
torch.autocast(device_type=self._device.type)
|
||||||
torch.inference_mode(),
|
if self._device.type == "cuda" and self._policy.config.use_amp
|
||||||
torch.autocast(device_type=policy_device.type)
|
else nullcontext()
|
||||||
if policy_device.type == "cuda" and self._policy.config.use_amp
|
)
|
||||||
else torch.inference_mode(),
|
with torch.inference_mode(), autocast_ctx:
|
||||||
):
|
|
||||||
observation = prepare_observation_for_inference(
|
observation = prepare_observation_for_inference(
|
||||||
observation, policy_device, self._task, self._robot_type
|
observation, self._device, self._task, self._robot_type
|
||||||
)
|
)
|
||||||
observation = self._preprocessor(observation)
|
observation = self._preprocessor(observation)
|
||||||
action = self._policy.select_action(observation)
|
action = self._policy.select_action(observation)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class RolloutRingBuffer:
|
class RolloutRingBuffer:
|
||||||
@@ -28,6 +29,12 @@ class RolloutRingBuffer:
|
|||||||
time (``max_frames``) and memory (``max_memory_bytes``). When either
|
time (``max_frames``) and memory (``max_memory_bytes``). When either
|
||||||
limit is reached the oldest frames are evicted.
|
limit is reached the oldest frames are evicted.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
This class is **single-threaded**. ``append``/``drain``/``clear``
|
||||||
|
must all be called from the same thread (the rollout main loop).
|
||||||
|
Concurrent access from a background thread will corrupt
|
||||||
|
``_current_bytes`` accounting.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
max_seconds:
|
max_seconds:
|
||||||
@@ -91,7 +98,11 @@ def _estimate_frame_bytes(frame: dict) -> int:
|
|||||||
"""Rough byte estimate for a single frame dictionary."""
|
"""Rough byte estimate for a single frame dictionary."""
|
||||||
total = 0
|
total = 0
|
||||||
for v in frame.values():
|
for v in frame.values():
|
||||||
if isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
|
if isinstance(v, torch.Tensor):
|
||||||
|
# ``torch.Tensor`` has no ``nbytes``; compute it explicitly so the
|
||||||
|
# memory cap is honoured even when frames hold unconverted tensors.
|
||||||
|
total += v.nelement() * v.element_size()
|
||||||
|
elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
|
||||||
total += v.nbytes
|
total += v.nbytes
|
||||||
elif isinstance(v, (int, float)):
|
elif isinstance(v, (int, float)):
|
||||||
total += 8
|
total += 8
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ import contextlib
|
|||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from threading import Lock
|
from threading import Event, Lock
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -90,14 +90,16 @@ class DAggerEvents:
|
|||||||
self._phase = DAggerPhase.AUTONOMOUS
|
self._phase = DAggerPhase.AUTONOMOUS
|
||||||
self._pending_transition: str | None = None
|
self._pending_transition: str | None = None
|
||||||
|
|
||||||
# Episode-level flags (written by keyboard, consumed by main loop)
|
# Episode-level flags written by keyboard/pedal threads, consumed by
|
||||||
self.exit_early: bool = False
|
# the main loop. ``threading.Event`` gives us atomic set/clear/check
|
||||||
self.rerecord_episode: bool = False
|
# semantics without taking ``self._lock``.
|
||||||
self.stop_recording: bool = False
|
self.exit_early = Event()
|
||||||
|
self.rerecord_episode = Event()
|
||||||
|
self.stop_recording = Event()
|
||||||
|
|
||||||
# Reset-phase flags (simpler lifecycle, shared between threads)
|
# Reset-phase flags (simpler lifecycle, shared between threads).
|
||||||
self.in_reset: bool = False
|
self.in_reset = Event()
|
||||||
self.start_next_episode: bool = False
|
self.start_next_episode = Event()
|
||||||
|
|
||||||
# -- Thread-safe phase access ------------------------------------------
|
# -- Thread-safe phase access ------------------------------------------
|
||||||
|
|
||||||
@@ -140,8 +142,8 @@ class DAggerEvents:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._phase = DAggerPhase.AUTONOMOUS
|
self._phase = DAggerPhase.AUTONOMOUS
|
||||||
self._pending_transition = None
|
self._pending_transition = None
|
||||||
self.exit_early = False
|
self.exit_early.clear()
|
||||||
self.rerecord_episode = False
|
self.rerecord_episode.clear()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -198,24 +200,24 @@ def _reset_loop(
|
|||||||
"""Reset period where the human repositions the environment."""
|
"""Reset period where the human repositions the environment."""
|
||||||
logger.info("RESET — press any key to enable teleoperation")
|
logger.info("RESET — press any key to enable teleoperation")
|
||||||
|
|
||||||
events.in_reset = True
|
events.in_reset.set()
|
||||||
events.start_next_episode = False
|
events.start_next_episode.clear()
|
||||||
|
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||||
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||||
|
|
||||||
while not events.start_next_episode and not events.stop_recording:
|
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
|
||||||
precise_sleep(0.05)
|
precise_sleep(0.05)
|
||||||
|
|
||||||
if events.stop_recording:
|
if events.stop_recording.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
events.start_next_episode = False
|
events.start_next_episode.clear()
|
||||||
_teleop_disable_torque(teleop)
|
_teleop_disable_torque(teleop)
|
||||||
logger.info("Teleop enabled — press any key to start episode")
|
logger.info("Teleop enabled — press any key to start episode")
|
||||||
|
|
||||||
while not events.start_next_episode and not events.stop_recording:
|
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
|
||||||
loop_start = time.perf_counter()
|
loop_start = time.perf_counter()
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
action = teleop.get_action()
|
action = teleop.get_action()
|
||||||
@@ -224,8 +226,8 @@ def _reset_loop(
|
|||||||
robot.send_action(robot_action_to_send)
|
robot.send_action(robot_action_to_send)
|
||||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||||
|
|
||||||
events.in_reset = False
|
events.in_reset.clear()
|
||||||
events.start_next_episode = False
|
events.start_next_episode.clear()
|
||||||
events.reset_for_episode()
|
events.reset_for_episode()
|
||||||
|
|
||||||
|
|
||||||
@@ -242,16 +244,16 @@ def _init_dagger_keyboard(events: DAggerEvents):
|
|||||||
|
|
||||||
def on_press(key):
|
def on_press(key):
|
||||||
try:
|
try:
|
||||||
if events.in_reset:
|
if events.in_reset.is_set():
|
||||||
if (
|
if (
|
||||||
key in [keyboard.Key.space, keyboard.Key.right]
|
key in [keyboard.Key.space, keyboard.Key.right]
|
||||||
or hasattr(key, "char")
|
or hasattr(key, "char")
|
||||||
and key.char == "c"
|
and key.char == "c"
|
||||||
):
|
):
|
||||||
events.start_next_episode = True
|
events.start_next_episode.set()
|
||||||
elif key == keyboard.Key.esc:
|
elif key == keyboard.Key.esc:
|
||||||
events.stop_recording = True
|
events.stop_recording.set()
|
||||||
events.start_next_episode = True
|
events.start_next_episode.set()
|
||||||
return
|
return
|
||||||
|
|
||||||
phase = events.phase
|
phase = events.phase
|
||||||
@@ -275,15 +277,15 @@ def _init_dagger_keyboard(events: DAggerEvents):
|
|||||||
|
|
||||||
elif key == keyboard.Key.right:
|
elif key == keyboard.Key.right:
|
||||||
logger.info("End episode")
|
logger.info("End episode")
|
||||||
events.exit_early = True
|
events.exit_early.set()
|
||||||
elif key == keyboard.Key.left:
|
elif key == keyboard.Key.left:
|
||||||
logger.info("Re-record episode")
|
logger.info("Re-record episode")
|
||||||
events.rerecord_episode = True
|
events.rerecord_episode.set()
|
||||||
events.exit_early = True
|
events.exit_early.set()
|
||||||
elif key == keyboard.Key.esc:
|
elif key == keyboard.Key.esc:
|
||||||
logger.info("Stop recording...")
|
logger.info("Stop recording...")
|
||||||
events.stop_recording = True
|
events.stop_recording.set()
|
||||||
events.exit_early = True
|
events.exit_early.set()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Key error: %s", e)
|
logger.debug("Key error: %s", e)
|
||||||
|
|
||||||
@@ -301,8 +303,8 @@ def _dagger_pedal_callback(events: DAggerEvents):
|
|||||||
def on_press(code: str) -> None:
|
def on_press(code: str) -> None:
|
||||||
if code not in _DAGGER_PEDAL_KEYS:
|
if code not in _DAGGER_PEDAL_KEYS:
|
||||||
return
|
return
|
||||||
if events.in_reset:
|
if events.in_reset.is_set():
|
||||||
events.start_next_episode = True
|
events.start_next_episode.set()
|
||||||
return
|
return
|
||||||
phase = events.phase
|
phase = events.phase
|
||||||
if phase == DAggerPhase.CORRECTING:
|
if phase == DAggerPhase.CORRECTING:
|
||||||
@@ -362,22 +364,22 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
with VideoEncodingManager(dataset):
|
with VideoEncodingManager(dataset):
|
||||||
try:
|
try:
|
||||||
recorded = 0
|
recorded = 0
|
||||||
while recorded < self.config.num_episodes and not events.stop_recording:
|
while recorded < self.config.num_episodes and not events.stop_recording.is_set():
|
||||||
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
|
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
|
||||||
|
|
||||||
self._run_episode(ctx)
|
self._run_episode(ctx)
|
||||||
|
|
||||||
if events.rerecord_episode:
|
if events.rerecord_episode.is_set():
|
||||||
log_say("Re-recording", self.config.play_sounds)
|
log_say("Re-recording", self.config.play_sounds)
|
||||||
events.rerecord_episode = False
|
events.rerecord_episode.clear()
|
||||||
events.exit_early = False
|
events.exit_early.clear()
|
||||||
dataset.clear_episode_buffer()
|
dataset.clear_episode_buffer()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
recorded += 1
|
recorded += 1
|
||||||
|
|
||||||
if recorded < self.config.num_episodes and not events.stop_recording:
|
if recorded < self.config.num_episodes and not events.stop_recording.is_set():
|
||||||
_reset_loop(
|
_reset_loop(
|
||||||
ctx.hardware.robot_wrapper,
|
ctx.hardware.robot_wrapper,
|
||||||
teleop,
|
teleop,
|
||||||
@@ -448,8 +450,8 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
while timestamp < self.config.episode_time_s:
|
while timestamp < self.config.episode_time_s:
|
||||||
loop_start = time.perf_counter()
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
if events.exit_early:
|
if events.exit_early.is_set():
|
||||||
events.exit_early = False
|
events.exit_early.clear()
|
||||||
break
|
break
|
||||||
|
|
||||||
transition = events.consume_transition()
|
transition = events.consume_transition()
|
||||||
|
|||||||
@@ -118,6 +118,14 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||||
|
|
||||||
|
# NOTE: ``is_set()`` then ``clear()`` is not atomic
|
||||||
|
# against the keyboard thread setting the flag again
|
||||||
|
# in between — but that is benign: we lose at most one
|
||||||
|
# toggle, processed on the next iteration. The
|
||||||
|
# ``_recording_live`` branch below is reached in the
|
||||||
|
# SAME iteration after ``clear()`` runs, so a frame
|
||||||
|
# finalised by ``save_episode()`` is never re-added to
|
||||||
|
# the next episode.
|
||||||
if self._save_requested.is_set():
|
if self._save_requested.is_set():
|
||||||
self._save_requested.clear()
|
self._save_requested.clear()
|
||||||
if not self._recording_live.is_set():
|
if not self._recording_live.is_set():
|
||||||
|
|||||||
@@ -108,10 +108,18 @@ class SentryStrategy(RolloutStrategy):
|
|||||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||||
|
# ``add_frame`` writes to the in-progress episode buffer; the
|
||||||
|
# background pusher only ever touches *finalised* episode
|
||||||
|
# artifacts on disk. The two operate on disjoint state, so
|
||||||
|
# ``add_frame`` does not need ``_episode_lock``.
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
elapsed = time.perf_counter() - episode_start
|
elapsed = time.perf_counter() - episode_start
|
||||||
if elapsed >= self.config.episode_duration_s:
|
if elapsed >= self.config.episode_duration_s:
|
||||||
|
# ``save_episode`` finalises the in-progress episode and
|
||||||
|
# flushes it to disk; ``_episode_lock`` serialises this with
|
||||||
|
# ``push_to_hub`` (run in the background executor) so the
|
||||||
|
# pusher never reads a half-written episode.
|
||||||
with self._episode_lock:
|
with self._episode_lock:
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
episodes_since_push += 1
|
episodes_since_push += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user