|
|
|
@@ -112,6 +112,11 @@ class DAggerEvents:
|
|
|
|
|
# Session-level flags
|
|
|
|
|
self.stop_recording = Event()
|
|
|
|
|
self.upload_requested = Event()
|
|
|
|
|
# Set when operator presses success/failure key to end the current episode.
|
|
|
|
|
self.save_episode_requested = Event()
|
|
|
|
|
|
|
|
|
|
# Episode success labeling
|
|
|
|
|
self._episode_success: bool | None = None
|
|
|
|
|
|
|
|
|
|
# -- Thread-safe phase access ------------------------------------------
|
|
|
|
|
|
|
|
|
@@ -155,7 +160,26 @@ class DAggerEvents:
|
|
|
|
|
with self._lock:
|
|
|
|
|
self._phase = DAggerPhase.AUTONOMOUS
|
|
|
|
|
self._pending_transition = None
|
|
|
|
|
self._episode_success = None
|
|
|
|
|
self.upload_requested.clear()
|
|
|
|
|
self.save_episode_requested.clear()
|
|
|
|
|
|
|
|
|
|
def mark_success(self) -> None:
    """Record a success label for the episode currently in progress.

    Meant to be invoked from input threads; the label is written while
    holding ``self._lock``.
    """
    outcome = True
    with self._lock:
        self._episode_success = outcome
|
|
|
|
|
|
|
|
|
|
def mark_failure(self) -> None:
    """Record a failure label for the episode currently in progress.

    Meant to be invoked from input threads; the label is written while
    holding ``self._lock``.
    """
    outcome = False
    with self._lock:
        self._episode_success = outcome
|
|
|
|
|
|
|
|
|
|
def consume_episode_success(self) -> bool | None:
|
|
|
|
|
"""Consume and reset the episode success label. Returns None if unlabeled."""
|
|
|
|
|
with self._lock:
|
|
|
|
|
result = self._episode_success
|
|
|
|
|
self._episode_success = None
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@@ -186,12 +210,20 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
|
|
|
|
events.request_transition(key_to_event[name])
|
|
|
|
|
if name == cfg.upload:
|
|
|
|
|
events.upload_requested.set()
|
|
|
|
|
if name == cfg.success:
|
|
|
|
|
events.mark_success()
|
|
|
|
|
events.save_episode_requested.set()
|
|
|
|
|
logger.info("Episode marked as SUCCESS — saving")
|
|
|
|
|
if name == cfg.failure:
|
|
|
|
|
events.mark_failure()
|
|
|
|
|
events.save_episode_requested.set()
|
|
|
|
|
logger.info("Episode marked as FAILURE — saving")
|
|
|
|
|
|
|
|
|
|
return create_key_listener(
|
|
|
|
|
dispatch,
|
|
|
|
|
controls_help=(
|
|
|
|
|
f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
|
|
|
|
|
f"upload='{cfg.upload}', ESC=stop"
|
|
|
|
|
f"upload='{cfg.upload}', success='{cfg.success}', failure='{cfg.failure}', ESC=stop"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@@ -313,6 +345,32 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
)
|
|
|
|
|
logger.info("DAgger strategy teardown complete")
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
# Episode success labeling
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
def _stamp_episode_success(self, dataset) -> None:
    """Set ``next.success`` on the terminal frame based on the operator label.

    Called just before ``save_episode()``. The pending label is always
    consumed here, even when there is nothing to stamp, so a label set for
    one episode cannot carry over to a later one. If the operator marked the
    episode as successful, the last buffered ``next.success`` entry is set
    to True; otherwise the buffered values are left exactly as recorded.

    Args:
        dataset: Dataset whose ``writer.episode_buffer`` holds the frames of
            the episode about to be saved.
    """
    # Consume first: the early return below must not leave a stale label
    # behind that a later episode would pick up.
    label = self._events.consume_episode_success()

    buf = dataset.writer.episode_buffer
    success_buf = buf.get("next.success") if buf is not None else None
    if not success_buf:
        if label:
            logger.warning("Success label dropped: no buffered next.success frames to stamp")
        return

    logger.info("_stamp_episode_success: label=%s, buffer_len=%d", label, len(success_buf))

    if label:
        # Only the terminal frame carries the success flag.
        success_buf[-1] = np.array([True], dtype=bool)
        logger.info("Terminal frame stamped next.success=True")
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
# Continuous recording mode (record_autonomous=True)
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
@@ -350,7 +408,12 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
episode_start = time.perf_counter()
|
|
|
|
|
episodes_since_push = 0
|
|
|
|
|
episode_duration_s = self._episode_duration_s
|
|
|
|
|
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
|
|
|
|
|
num_episodes = self.config.num_episodes
|
|
|
|
|
logger.info(
|
|
|
|
|
"DAgger continuous recording started (episode_duration=%.0fs, target=%s eps)",
|
|
|
|
|
episode_duration_s,
|
|
|
|
|
num_episodes if num_episodes is not None else "∞",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
with VideoEncodingManager(dataset):
|
|
|
|
|
try:
|
|
|
|
@@ -399,6 +462,7 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
**action_frame,
|
|
|
|
|
"task": task_str,
|
|
|
|
|
"intervention": np.array([True], dtype=bool),
|
|
|
|
|
"next.success": np.array([False], dtype=bool),
|
|
|
|
|
}
|
|
|
|
|
dataset.add_frame(frame)
|
|
|
|
|
record_tick += 1
|
|
|
|
@@ -427,23 +491,32 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
**action_frame,
|
|
|
|
|
"task": task_str,
|
|
|
|
|
"intervention": np.array([False], dtype=bool),
|
|
|
|
|
"next.success": np.array([False], dtype=bool),
|
|
|
|
|
}
|
|
|
|
|
dataset.add_frame(frame)
|
|
|
|
|
record_tick += 1
|
|
|
|
|
|
|
|
|
|
# Episode rotation derived from the video file-size target.
|
|
|
|
|
# Saving is deferred while a correction is ongoing so the
|
|
|
|
|
# episode boundary lands on a clean autonomous frame.
|
|
|
|
|
# Episode rotation: either the operator pressed success/failure,
|
|
|
|
|
# or the video file-size target was reached.
|
|
|
|
|
# Defer the save while a correction is ongoing so the episode
|
|
|
|
|
# boundary lands on a clean autonomous frame. The event stays
|
|
|
|
|
# set until we actually save, so it won't be lost.
|
|
|
|
|
manual_save = events.save_episode_requested.is_set()
|
|
|
|
|
|
|
|
|
|
elapsed = time.perf_counter() - episode_start
|
|
|
|
|
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
|
|
|
|
|
if (manual_save or elapsed >= episode_duration_s) and phase != DAggerPhase.CORRECTING:
|
|
|
|
|
if manual_save:
|
|
|
|
|
events.save_episode_requested.clear()
|
|
|
|
|
with self._episode_lock:
|
|
|
|
|
self._stamp_episode_success(dataset)
|
|
|
|
|
dataset.save_episode()
|
|
|
|
|
episodes_since_push += 1
|
|
|
|
|
self._needs_push.set()
|
|
|
|
|
save_reason = "manual save" if manual_save else f"elapsed {elapsed:.1f}s"
|
|
|
|
|
logger.info(
|
|
|
|
|
"Episode saved (total: %d, elapsed: %.1fs)",
|
|
|
|
|
"Episode saved (%s, total: %d)",
|
|
|
|
|
save_reason,
|
|
|
|
|
dataset.num_episodes,
|
|
|
|
|
elapsed,
|
|
|
|
|
)
|
|
|
|
|
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
|
|
|
|
|
|
|
|
|
@@ -451,6 +524,25 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
self._background_push(dataset, cfg)
|
|
|
|
|
episodes_since_push = 0
|
|
|
|
|
|
|
|
|
|
if num_episodes is not None and dataset.num_episodes >= num_episodes:
|
|
|
|
|
logger.info("Target episode count reached (%d), stopping session", num_episodes)
|
|
|
|
|
log_say(f"All {num_episodes} episodes collected", play_sounds)
|
|
|
|
|
events.stop_recording.set()
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# Pause after manual save: stop the policy, return robot to
|
|
|
|
|
# initial position, and wait for the operator to reset the
|
|
|
|
|
# environment and press SPACE.
|
|
|
|
|
if manual_save:
|
|
|
|
|
engine.pause()
|
|
|
|
|
events.phase = DAggerPhase.PAUSED
|
|
|
|
|
self._return_to_initial_position(ctx.hardware)
|
|
|
|
|
last_action = None
|
|
|
|
|
logger.info(
|
|
|
|
|
"Episode saved — paused for environment reset. Press SPACE to start next episode."
|
|
|
|
|
)
|
|
|
|
|
log_say("Reset the environment, then press space", play_sounds)
|
|
|
|
|
|
|
|
|
|
episode_start = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
dt = time.perf_counter() - loop_start
|
|
|
|
@@ -465,10 +557,13 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
logger.info("DAgger continuous control loop ended — pausing engine")
|
|
|
|
|
engine.pause()
|
|
|
|
|
with contextlib.suppress(Exception):
|
|
|
|
|
with self._episode_lock:
|
|
|
|
|
dataset.save_episode()
|
|
|
|
|
self._needs_push.set()
|
|
|
|
|
logger.info("Final in-progress episode saved")
|
|
|
|
|
buf = dataset.writer.episode_buffer
|
|
|
|
|
if buf and any(len(v) > 0 for v in buf.values() if isinstance(v, list)):
|
|
|
|
|
with self._episode_lock:
|
|
|
|
|
self._stamp_episode_success(dataset)
|
|
|
|
|
dataset.save_episode()
|
|
|
|
|
self._needs_push.set()
|
|
|
|
|
logger.info("Final in-progress episode saved")
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
# Corrections-only mode (record_autonomous=False)
|
|
|
|
@@ -540,6 +635,7 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
# Correction ended -> save episode (blocking if not streaming)
|
|
|
|
|
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
|
|
|
|
with self._episode_lock:
|
|
|
|
|
self._stamp_episode_success(dataset)
|
|
|
|
|
dataset.save_episode()
|
|
|
|
|
recorded += 1
|
|
|
|
|
self._needs_push.set()
|
|
|
|
@@ -581,6 +677,7 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
**action_frame,
|
|
|
|
|
"task": task_str,
|
|
|
|
|
"intervention": np.array([True], dtype=bool),
|
|
|
|
|
"next.success": np.array([False], dtype=bool),
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
record_tick += 1
|
|
|
|
@@ -614,10 +711,13 @@ class DAggerStrategy(RolloutStrategy):
|
|
|
|
|
logger.info("DAgger corrections-only loop ended — pausing engine")
|
|
|
|
|
engine.pause()
|
|
|
|
|
with contextlib.suppress(Exception):
|
|
|
|
|
with self._episode_lock:
|
|
|
|
|
dataset.save_episode()
|
|
|
|
|
self._needs_push.set()
|
|
|
|
|
logger.info("Final in-progress episode saved")
|
|
|
|
|
buf = dataset.writer.episode_buffer
|
|
|
|
|
if buf and any(len(v) > 0 for v in buf.values() if isinstance(v, list)):
|
|
|
|
|
with self._episode_lock:
|
|
|
|
|
self._stamp_episode_success(dataset)
|
|
|
|
|
dataset.save_episode()
|
|
|
|
|
self._needs_push.set()
|
|
|
|
|
logger.info("Final in-progress episode saved")
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
# State-machine transition side-effects
|
|
|
|
|