From 44f76dbbf0b70a5a97875c32efb43b8c99eced1e Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Thu, 2 Jul 2026 17:13:05 +0200 Subject: [PATCH] feat(rollout): add episode success/failure labeling to DAgger strategy Enable operators to mark episodes as success or failure during DAgger data collection. Pressing 's' or 'f' immediately saves the episode with the appropriate label and returns the robot to its initial position. - Add success/failure key bindings to DAggerKeyboardConfig - Add save_episode_requested event and episode_success state to DAggerEvents - Stamp next.success=True on terminal frame for successful episodes - Pause and return to initial position after manual save for env reset - Add num_episodes target to stop continuous recording automatically - Defer save during corrections to avoid splitting mid-intervention --- src/lerobot/rollout/configs.py | 6 ++ src/lerobot/rollout/context.py | 5 ++ src/lerobot/rollout/strategies/dagger.py | 110 +++++++++++++++++++++-- 3 files changed, 113 insertions(+), 8 deletions(-) diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index 60c47cfba..b726f5476 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -106,6 +106,8 @@ class DAggerKeyboardConfig: pause_resume: str = "space" correction: str = "tab" upload: str = "enter" + success: str = "s" + failure: str = "f" @dataclass @@ -165,6 +167,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig): 2. **correction** — toggle human correction recording. 3. **upload** — push dataset to hub on demand (corrections-only mode). + Episode success labeling: + 4. **success** — mark current episode as successful. + 5. **failure** — mark current episode as failed. + When ``record_autonomous=False`` (default) only human-correction windows are recorded — each correction becomes its own episode. Set to ``True`` to record both autonomous and correction frames with size-based episode diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index 20a7d715a..863dc1058 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -350,6 +350,11 @@ def build_rollout_context( "shape": (1,), "names": None, } + dataset_features["next.success"] = { + "dtype": "bool", + "shape": (1,), + "names": None, + } repo_name = cfg.dataset.repo_id.split("/", 1)[-1] if not repo_name.startswith("rollout_"): diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 21d1e8e98..01a6f2f70 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -112,6 +112,11 @@ class DAggerEvents: # Session-level flags self.stop_recording = Event() self.upload_requested = Event() + # Set when operator presses success/failure key to end the current episode. + self.save_episode_requested = Event() + + # Episode success labeling + self._episode_success: bool | None = None # -- Thread-safe phase access ------------------------------------------ @@ -155,7 +160,26 @@ class DAggerEvents: with self._lock: self._phase = DAggerPhase.AUTONOMOUS self._pending_transition = None + self._episode_success = None self.upload_requested.clear() + self.save_episode_requested.clear() + + def mark_success(self) -> None: + """Mark the current episode as successful (called from input threads).""" + with self._lock: + self._episode_success = True + + def mark_failure(self) -> None: + """Mark the current episode as failed (called from input threads).""" + with self._lock: + self._episode_success = False + + def consume_episode_success(self) -> bool | None: + """Consume and reset the episode success label. Returns None if unlabeled.""" + with self._lock: + result = self._episode_success + self._episode_success = None + return result # --------------------------------------------------------------------------- @@ -186,12 +210,20 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig): events.request_transition(key_to_event[name]) if name == cfg.upload: events.upload_requested.set() + if name == cfg.success: + events.mark_success() + events.save_episode_requested.set() + logger.info("Episode marked as SUCCESS — saving") + if name == cfg.failure: + events.mark_failure() + events.save_episode_requested.set() + logger.info("Episode marked as FAILURE — saving") return create_key_listener( dispatch, controls_help=( f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', " - f"upload='{cfg.upload}', ESC=stop" + f"upload='{cfg.upload}', success='{cfg.success}', failure='{cfg.failure}', ESC=stop" ), ) @@ -313,6 +345,31 @@ class DAggerStrategy(RolloutStrategy): ) logger.info("DAgger strategy teardown complete") + # ------------------------------------------------------------------ + # Episode success labeling + # ------------------------------------------------------------------ + + def _stamp_episode_success(self, dataset) -> None: + """Set next.success on the terminal frame based on operator label. + + Called just before save_episode(). If the operator pressed the success + key during this episode, the last frame's next.success is set to True. + Otherwise all frames remain False (unlabeled = assumed failure). + """ + buf = dataset.writer.episode_buffer + if buf is None: + return + + success_buf = buf.get("next.success") + if not success_buf: + return + + label = self._events.consume_episode_success() + + if label: + success_buf[-1] = np.array([True], dtype=bool) + logger.info("Terminal frame stamped next.success=True") + # ------------------------------------------------------------------ # Continuous recording mode (record_autonomous=True) # ------------------------------------------------------------------ @@ -350,7 +407,12 @@ class DAggerStrategy(RolloutStrategy): episode_start = time.perf_counter() episodes_since_push = 0 episode_duration_s = self._episode_duration_s - logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s) + num_episodes = self.config.num_episodes + logger.info( + "DAgger continuous recording started (episode_duration=%.0fs, target=%s eps)", + episode_duration_s, + num_episodes if num_episodes is not None else "∞", + ) with VideoEncodingManager(dataset): try: @@ -399,6 +461,7 @@ class DAggerStrategy(RolloutStrategy): **action_frame, "task": task_str, "intervention": np.array([True], dtype=bool), + "next.success": np.array([False], dtype=bool), } dataset.add_frame(frame) record_tick += 1 @@ -427,23 +490,32 @@ class DAggerStrategy(RolloutStrategy): **action_frame, "task": task_str, "intervention": np.array([False], dtype=bool), + "next.success": np.array([False], dtype=bool), } dataset.add_frame(frame) record_tick += 1 - # Episode rotation derived from the video file-size target. - # Saving is deferred while a correction is ongoing so the - # episode boundary lands on a clean autonomous frame. + # Episode rotation: either the operator pressed success/failure, + # or the video file-size target was reached. + # Defer the save while a correction is ongoing so the episode + # boundary lands on a clean autonomous frame. The event stays + # set until we actually save, so it won't be lost. + manual_save = events.save_episode_requested.is_set() + elapsed = time.perf_counter() - episode_start - if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING: + if (manual_save or elapsed >= episode_duration_s) and phase != DAggerPhase.CORRECTING: + if manual_save: + events.save_episode_requested.clear() with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() episodes_since_push += 1 self._needs_push.set() + save_reason = "manual save" if manual_save else f"elapsed {elapsed:.1f}s" logger.info( - "Episode saved (total: %d, elapsed: %.1fs)", + "Episode saved (%s, total: %d)", + save_reason, dataset.num_episodes, - elapsed, ) log_say(f"Episode {dataset.num_episodes} saved", play_sounds) @@ -451,6 +523,24 @@ class DAggerStrategy(RolloutStrategy): self._background_push(dataset, cfg) episodes_since_push = 0 + if num_episodes is not None and dataset.num_episodes >= num_episodes: + logger.info("Target episode count reached (%d), stopping session", num_episodes) + log_say(f"All {num_episodes} episodes collected", play_sounds) + events.stop_recording.set() + break + + # Pause after manual save: stop the policy, return robot to + # initial position, and wait for the operator to reset the + # environment and press SPACE. + if manual_save: + engine.pause() + events.phase = DAggerPhase.PAUSED + self._return_to_initial_position(ctx.hardware) + logger.info( + "Episode saved — paused for environment reset. Press SPACE to start next episode." + ) + log_say("Reset the environment, then press space", play_sounds) + episode_start = time.perf_counter() dt = time.perf_counter() - loop_start @@ -466,6 +556,7 @@ class DAggerStrategy(RolloutStrategy): engine.pause() with contextlib.suppress(Exception): with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() self._needs_push.set() logger.info("Final in-progress episode saved") @@ -540,6 +631,7 @@ class DAggerStrategy(RolloutStrategy): # Correction ended -> save episode (blocking if not streaming) if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED: with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() recorded += 1 self._needs_push.set() @@ -581,6 +673,7 @@ class DAggerStrategy(RolloutStrategy): **action_frame, "task": task_str, "intervention": np.array([True], dtype=bool), + "next.success": np.array([False], dtype=bool), } ) record_tick += 1 @@ -615,6 +708,7 @@ class DAggerStrategy(RolloutStrategy): engine.pause() with contextlib.suppress(Exception): with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() self._needs_push.set() logger.info("Final in-progress episode saved")