feat(rollout): add episode success/failure labeling to DAgger strategy

Enable operators to mark episodes as success or failure during DAgger
data collection. Pressing 's' or 'f' saves the episode with the appropriate
label as soon as no correction is in progress, then pauses the policy and
returns the robot to its initial position.

- Add success/failure key bindings to DAggerKeyboardConfig
- Add save_episode_requested event and episode_success state to DAggerEvents
- Stamp next.success=True on terminal frame for successful episodes
- Pause and return to initial position after manual save for env reset
- Add num_episodes target to stop continuous recording automatically
- Defer save during corrections to avoid splitting mid-intervention
This commit is contained in:
Khalil Meftah
2026-07-02 17:13:05 +02:00
parent 2f2b567951
commit 44f76dbbf0
3 changed files with 113 additions and 8 deletions
+6
View File
@@ -106,6 +106,8 @@ class DAggerKeyboardConfig:
pause_resume: str = "space"
correction: str = "tab"
upload: str = "enter"
success: str = "s"
failure: str = "f"
@dataclass
@@ -165,6 +167,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
2. **correction** — toggle human correction recording.
3. **upload** — push dataset to hub on demand (corrections-only mode).
Episode success labeling:
4. **success** — mark current episode as successful.
5. **failure** — mark current episode as failed.
When ``record_autonomous=False`` (default) only human-correction windows
are recorded — each correction becomes its own episode. Set to ``True``
to record both autonomous and correction frames with size-based episode
+5
View File
@@ -350,6 +350,11 @@ def build_rollout_context(
"shape": (1,),
"names": None,
}
dataset_features["next.success"] = {
"dtype": "bool",
"shape": (1,),
"names": None,
}
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
if not repo_name.startswith("rollout_"):
+102 -8
View File
@@ -112,6 +112,11 @@ class DAggerEvents:
# Session-level flags
self.stop_recording = Event()
self.upload_requested = Event()
# Set when operator presses success/failure key to end the current episode.
self.save_episode_requested = Event()
# Episode success labeling
self._episode_success: bool | None = None
# -- Thread-safe phase access ------------------------------------------
@@ -155,7 +160,26 @@ class DAggerEvents:
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self._episode_success = None
self.upload_requested.clear()
self.save_episode_requested.clear()
def mark_success(self) -> None:
    """Label the in-progress episode as a success.

    Intended to be called from input threads; the label is written while
    holding ``self._lock``.
    """
    label = True
    with self._lock:
        self._episode_success = label
def mark_failure(self) -> None:
    """Label the in-progress episode as a failure.

    Intended to be called from input threads; the label is written while
    holding ``self._lock``.
    """
    label = False
    with self._lock:
        self._episode_success = label
def consume_episode_success(self) -> bool | None:
"""Consume and reset the episode success label. Returns None if unlabeled."""
with self._lock:
result = self._episode_success
self._episode_success = None
return result
# ---------------------------------------------------------------------------
@@ -186,12 +210,20 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
events.request_transition(key_to_event[name])
if name == cfg.upload:
events.upload_requested.set()
if name == cfg.success:
events.mark_success()
events.save_episode_requested.set()
logger.info("Episode marked as SUCCESS — saving")
if name == cfg.failure:
events.mark_failure()
events.save_episode_requested.set()
logger.info("Episode marked as FAILURE — saving")
return create_key_listener(
dispatch,
controls_help=(
f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
f"upload='{cfg.upload}', ESC=stop"
f"upload='{cfg.upload}', success='{cfg.success}', failure='{cfg.failure}', ESC=stop"
),
)
@@ -313,6 +345,31 @@ class DAggerStrategy(RolloutStrategy):
)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
# Episode success labeling
# ------------------------------------------------------------------
def _stamp_episode_success(self, dataset) -> None:
    """Set ``next.success`` on the terminal frame based on the operator label.

    Called just before ``save_episode()``. The pending label is always
    consumed, even when there is no frame to stamp, so a key press made
    while nothing was buffered cannot leak onto a later episode. If the
    operator pressed the success key during this episode, the last buffered
    frame's ``next.success`` is set to True. Otherwise every frame keeps the
    value it was recorded with (unlabeled or failure = False).

    Args:
        dataset: Dataset whose ``writer.episode_buffer`` holds the frames of
            the episode about to be saved.
    """
    # Consume before any early return: a label left pending here would be
    # stamped onto the terminal frame of the next, unrelated episode.
    label = self._events.consume_episode_success()
    if not label:
        return

    buf = dataset.writer.episode_buffer
    if buf is None:
        return
    success_buf = buf.get("next.success")
    if not success_buf:
        # No frame recorded for this episode; the label is dropped.
        return

    success_buf[-1] = np.array([True], dtype=bool)
    logger.info("Terminal frame stamped next.success=True")
# ------------------------------------------------------------------
# Continuous recording mode (record_autonomous=True)
# ------------------------------------------------------------------
@@ -350,7 +407,12 @@ class DAggerStrategy(RolloutStrategy):
episode_start = time.perf_counter()
episodes_since_push = 0
episode_duration_s = self._episode_duration_s
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
num_episodes = self.config.num_episodes
logger.info(
"DAgger continuous recording started (episode_duration=%.0fs, target=%s eps)",
episode_duration_s,
num_episodes if num_episodes is not None else "",
)
with VideoEncodingManager(dataset):
try:
@@ -399,6 +461,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
@@ -427,23 +490,32 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
# Episode rotation derived from the video file-size target.
# Saving is deferred while a correction is ongoing so the
# episode boundary lands on a clean autonomous frame.
# Episode rotation: either the operator pressed success/failure,
# or the video file-size target was reached.
# Defer the save while a correction is ongoing so the episode
# boundary lands on a clean autonomous frame. The event stays
# set until we actually save, so it won't be lost.
manual_save = events.save_episode_requested.is_set()
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
if (manual_save or elapsed >= episode_duration_s) and phase != DAggerPhase.CORRECTING:
if manual_save:
events.save_episode_requested.clear()
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
save_reason = "manual save" if manual_save else f"elapsed {elapsed:.1f}s"
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
"Episode saved (%s, total: %d)",
save_reason,
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
@@ -451,6 +523,24 @@ class DAggerStrategy(RolloutStrategy):
self._background_push(dataset, cfg)
episodes_since_push = 0
if num_episodes is not None and dataset.num_episodes >= num_episodes:
logger.info("Target episode count reached (%d), stopping session", num_episodes)
log_say(f"All {num_episodes} episodes collected", play_sounds)
events.stop_recording.set()
break
# Pause after manual save: stop the policy, return robot to
# initial position, and wait for the operator to reset the
# environment and press SPACE.
if manual_save:
engine.pause()
events.phase = DAggerPhase.PAUSED
self._return_to_initial_position(ctx.hardware)
logger.info(
"Episode saved — paused for environment reset. Press SPACE to start next episode."
)
log_say("Reset the environment, then press space", play_sounds)
episode_start = time.perf_counter()
dt = time.perf_counter() - loop_start
@@ -466,6 +556,7 @@ class DAggerStrategy(RolloutStrategy):
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
@@ -540,6 +631,7 @@ class DAggerStrategy(RolloutStrategy):
# Correction ended -> save episode (blocking if not streaming)
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
recorded += 1
self._needs_push.set()
@@ -581,6 +673,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
)
record_tick += 1
@@ -615,6 +708,7 @@ class DAggerStrategy(RolloutStrategy):
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")