Compare commits

...

5 Commits

Author SHA1 Message Date
Khalil Meftah 4af7095693 Merge branch 'main' into feat/rollout/dagger-episode-save 2026-07-03 16:50:10 +02:00
Khalil Meftah 46d4ddc698 chore(rollout): log episode success label and buffer length 2026-07-02 19:12:10 +02:00
Khalil Meftah b29ba27977 fix(rollout): guard empty buffer save 2026-07-02 18:02:59 +02:00
Khalil Meftah 599e2432e5 fix(rollout): clear last_action after return_to_initial 2026-07-02 18:02:36 +02:00
Khalil Meftah 44f76dbbf0 feat(rollout): add episode success/failure labeling to DAgger strategy
Enable operators to mark episodes as success or failure during DAgger
data collection. Pressing 's' or 'f' immediately saves the episode
with the appropriate label and returns the robot to its initial position.

- Add success/failure key bindings to DAggerKeyboardConfig
- Add save_episode_requested event and episode_success state to DAggerEvents
- Stamp next.success=True on terminal frame for successful episodes
- Pause and return to initial position after manual save for env reset
- Add num_episodes target to stop continuous recording automatically
- Defer save during corrections to avoid splitting mid-intervention
2026-07-02 17:48:02 +02:00
3 changed files with 127 additions and 16 deletions
+6
View File
@@ -106,6 +106,8 @@ class DAggerKeyboardConfig:
pause_resume: str = "space"
correction: str = "tab"
upload: str = "enter"
success: str = "s"
failure: str = "f"
@dataclass
@@ -165,6 +167,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
2. **correction** — toggle human correction recording.
3. **upload** — push dataset to hub on demand (corrections-only mode).
Episode success labeling:
4. **success** — mark current episode as successful.
5. **failure** — mark current episode as failed.
When ``record_autonomous=False`` (default) only human-correction windows
are recorded — each correction becomes its own episode. Set to ``True``
to record both autonomous and correction frames with size-based episode
+5
View File
@@ -350,6 +350,11 @@ def build_rollout_context(
"shape": (1,),
"names": None,
}
dataset_features["next.success"] = {
"dtype": "bool",
"shape": (1,),
"names": None,
}
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
if not repo_name.startswith("rollout_"):
+116 -16
View File
@@ -112,6 +112,11 @@ class DAggerEvents:
# Session-level flags
self.stop_recording = Event()
self.upload_requested = Event()
# Set when operator presses success/failure key to end the current episode.
self.save_episode_requested = Event()
# Episode success labeling
self._episode_success: bool | None = None
# -- Thread-safe phase access ------------------------------------------
@@ -155,7 +160,26 @@ class DAggerEvents:
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self._episode_success = None
self.upload_requested.clear()
self.save_episode_requested.clear()
def mark_success(self) -> None:
    """Label the in-progress episode as a success.

    Intended to be called from input threads, so the label is written
    while holding ``self._lock``.
    """
    with self._lock:
        self._episode_success = True
def mark_failure(self) -> None:
    """Label the in-progress episode as a failure.

    Intended to be called from input threads, so the label is written
    while holding ``self._lock``.
    """
    with self._lock:
        self._episode_success = False
def consume_episode_success(self) -> bool | None:
"""Consume and reset the episode success label. Returns None if unlabeled."""
with self._lock:
result = self._episode_success
self._episode_success = None
return result
# ---------------------------------------------------------------------------
@@ -186,12 +210,20 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
events.request_transition(key_to_event[name])
if name == cfg.upload:
events.upload_requested.set()
if name == cfg.success:
events.mark_success()
events.save_episode_requested.set()
logger.info("Episode marked as SUCCESS — saving")
if name == cfg.failure:
events.mark_failure()
events.save_episode_requested.set()
logger.info("Episode marked as FAILURE — saving")
return create_key_listener(
dispatch,
controls_help=(
f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
f"upload='{cfg.upload}', ESC=stop"
f"upload='{cfg.upload}', success='{cfg.success}', failure='{cfg.failure}', ESC=stop"
),
)
@@ -313,6 +345,32 @@ class DAggerStrategy(RolloutStrategy):
)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
# Episode success labeling
# ------------------------------------------------------------------
def _stamp_episode_success(self, dataset) -> None:
    """Set ``next.success`` on the terminal frame based on the operator label.

    Called just before ``save_episode()``. If the operator pressed the
    success key during this episode, the last frame's ``next.success`` is
    set to True. Otherwise all frames remain False (unlabeled = assumed
    failure).

    The label is consumed unconditionally, even when there is nothing to
    stamp, so a key press made for this episode can never carry over and
    be applied to a later one.

    Args:
        dataset: Object exposing ``writer.episode_buffer``, which is either
            ``None`` or a mapping with a ``"next.success"`` entry holding
            the per-frame values (indexed with ``[-1]`` here).
    """
    # Consume first: the label belongs to the episode being closed now.
    # Returning early without consuming would leave it armed for the next
    # episode that happens to reach this method.
    label = self._events.consume_episode_success()

    buf = dataset.writer.episode_buffer
    success_buf = buf.get("next.success") if buf is not None else None
    if not success_buf:
        if label is not None:
            logger.warning("Episode label %s discarded: no buffered frames to stamp", label)
        return

    logger.info("_stamp_episode_success: label=%s, buffer_len=%d", label, len(success_buf))
    if label:
        success_buf[-1] = np.array([True], dtype=bool)
        logger.info("Terminal frame stamped next.success=True")
# ------------------------------------------------------------------
# Continuous recording mode (record_autonomous=True)
# ------------------------------------------------------------------
@@ -350,7 +408,12 @@ class DAggerStrategy(RolloutStrategy):
episode_start = time.perf_counter()
episodes_since_push = 0
episode_duration_s = self._episode_duration_s
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
num_episodes = self.config.num_episodes
logger.info(
"DAgger continuous recording started (episode_duration=%.0fs, target=%s eps)",
episode_duration_s,
num_episodes if num_episodes is not None else "",
)
with VideoEncodingManager(dataset):
try:
@@ -399,6 +462,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
@@ -427,23 +491,32 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
# Episode rotation derived from the video file-size target.
# Saving is deferred while a correction is ongoing so the
# episode boundary lands on a clean autonomous frame.
# Episode rotation: either the operator pressed success/failure,
# or the video file-size target was reached.
# Defer the save while a correction is ongoing so the episode
# boundary lands on a clean autonomous frame. The event stays
# set until we actually save, so it won't be lost.
manual_save = events.save_episode_requested.is_set()
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
if (manual_save or elapsed >= episode_duration_s) and phase != DAggerPhase.CORRECTING:
if manual_save:
events.save_episode_requested.clear()
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
save_reason = "manual save" if manual_save else f"elapsed {elapsed:.1f}s"
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
"Episode saved (%s, total: %d)",
save_reason,
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
@@ -451,6 +524,25 @@ class DAggerStrategy(RolloutStrategy):
self._background_push(dataset, cfg)
episodes_since_push = 0
if num_episodes is not None and dataset.num_episodes >= num_episodes:
logger.info("Target episode count reached (%d), stopping session", num_episodes)
log_say(f"All {num_episodes} episodes collected", play_sounds)
events.stop_recording.set()
break
# Pause after manual save: stop the policy, return robot to
# initial position, and wait for the operator to reset the
# environment and press SPACE.
if manual_save:
engine.pause()
events.phase = DAggerPhase.PAUSED
self._return_to_initial_position(ctx.hardware)
last_action = None
logger.info(
"Episode saved — paused for environment reset. Press SPACE to start next episode."
)
log_say("Reset the environment, then press space", play_sounds)
episode_start = time.perf_counter()
dt = time.perf_counter() - loop_start
@@ -465,10 +557,13 @@ class DAggerStrategy(RolloutStrategy):
logger.info("DAgger continuous control loop ended — pausing engine")
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
buf = dataset.writer.episode_buffer
if buf and any(len(v) > 0 for v in buf.values() if isinstance(v, list)):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# Corrections-only mode (record_autonomous=False)
@@ -540,6 +635,7 @@ class DAggerStrategy(RolloutStrategy):
# Correction ended -> save episode (blocking if not streaming)
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
recorded += 1
self._needs_push.set()
@@ -581,6 +677,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
)
record_tick += 1
@@ -614,10 +711,13 @@ class DAggerStrategy(RolloutStrategy):
logger.info("DAgger corrections-only loop ended — pausing engine")
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
buf = dataset.writer.episode_buffer
if buf and any(len(v) > 0 for v in buf.values() if isinstance(v, list)):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# State-machine transition side-effects