From 0d2ba54385d5f775b30f8b03455cf55db0223a75 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 22 Jun 2026 15:08:05 +0200 Subject: [PATCH] feat(rollout): add episode success labeling to DAgger strategy --- src/lerobot/rollout/configs.py | 8 ++ src/lerobot/rollout/context.py | 5 ++ src/lerobot/rollout/strategies/dagger.py | 70 ++++++++++++++++- tests/test_rollout.py | 97 ++++++++++++++++++++++++ 4 files changed, 179 insertions(+), 1 deletion(-) diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index 60c47cfba..254797e90 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -106,6 +106,8 @@ class DAggerKeyboardConfig: pause_resume: str = "space" correction: str = "tab" upload: str = "enter" + success: str = "s" + failure: str = "f" @dataclass @@ -119,6 +121,8 @@ class DAggerPedalConfig: pause_resume: str = "KEY_A" correction: str = "KEY_B" upload: str = "KEY_C" + success: str = "KEY_D" + failure: str = "KEY_E" @RolloutStrategyConfig.register_subclass("episodic") @@ -165,6 +169,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig): 2. **correction** — toggle human correction recording. 3. **upload** — push dataset to hub on demand (corrections-only mode). + Episode success labeling: + 4. **success** — mark current episode as successful. + 5. **failure** — mark current episode as failed. + When ``record_autonomous=False`` (default) only human-correction windows are recorded — each correction becomes its own episode. Set to ``True`` to record both autonomous and correction frames with size-based episode diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index bf5fa0fd4..5e3b60674 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -347,6 +347,11 @@ def build_rollout_context( "shape": (1,), "names": None, } + dataset_features["next.success"] = { + "dtype": "bool", + "shape": (1,), + "names": None, + } repo_name = cfg.dataset.repo_id.split("/", 1)[-1] if not repo_name.startswith("rollout_"): diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 8791a5502..b46828f0c 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -129,6 +129,9 @@ class DAggerEvents: self.stop_recording = Event() self.upload_requested = Event() + # Episode success labeling + self._episode_success: bool | None = None + # -- Thread-safe phase access ------------------------------------------ @property @@ -171,8 +174,26 @@ class DAggerEvents: with self._lock: self._phase = DAggerPhase.AUTONOMOUS self._pending_transition = None + self._episode_success = None self.upload_requested.clear() + def mark_success(self) -> None: + """Mark the current episode as successful (called from input threads).""" + with self._lock: + self._episode_success = True + + def mark_failure(self) -> None: + """Mark the current episode as failed (called from input threads).""" + with self._lock: + self._episode_success = False + + def consume_episode_success(self) -> bool | None: + """Consume and reset the episode success label. Returns None if unlabeled.""" + with self._lock: + result = self._episode_success + self._episode_success = None + return result + # --------------------------------------------------------------------------- # Input device handlers @@ -226,16 +247,25 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig): events.request_transition(key_to_event[resolved]) if resolved == cfg.upload: events.upload_requested.set() + if resolved == cfg.success: + events.mark_success() + logger.info("Episode marked as SUCCESS") + if resolved == cfg.failure: + events.mark_failure() + logger.info("Episode marked as FAILURE") except Exception as e: logger.debug("Key error: %s", e) listener = keyboard.Listener(on_press=on_press) listener.start() logger.info( - "DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)", + "DAgger keyboard listener started (pause_resume='%s', correction='%s', " + "upload='%s', success='%s', failure='%s', ESC=stop)", cfg.pause_resume, cfg.correction, cfg.upload, + cfg.success, + cfg.failure, ) return listener @@ -255,6 +285,12 @@ def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig): events.request_transition(code_to_event[code]) if code == cfg.upload: events.upload_requested.set() + if code == cfg.success: + events.mark_success() + logger.info("Episode marked as SUCCESS (pedal)") + if code == cfg.failure: + events.mark_failure() + logger.info("Episode marked as FAILURE (pedal)") logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path) return start_pedal_listener(on_press, device_path=cfg.device_path) @@ -357,6 +393,31 @@ class DAggerStrategy(RolloutStrategy): ) logger.info("DAgger strategy teardown complete") + # ------------------------------------------------------------------ + # Episode success labeling + # ------------------------------------------------------------------ + + def _stamp_episode_success(self, dataset) -> None: + """Set next.success on the terminal frame based on operator label. + + Called just before save_episode(). If the operator pressed the success + key during this episode, the last frame's next.success is set to True. + Otherwise all frames remain False (unlabeled = assumed failure). + """ + buf = dataset.writer.episode_buffer + if buf is None: + return + + success_buf = buf.get("next.success") + if not success_buf: + return + + label = self._events.consume_episode_success() + + if label: + success_buf[-1] = np.array([True], dtype=bool) + logger.info("Terminal frame stamped next.success=True") + # ------------------------------------------------------------------ # Continuous recording mode (record_autonomous=True) # ------------------------------------------------------------------ @@ -443,6 +504,7 @@ class DAggerStrategy(RolloutStrategy): **action_frame, "task": task_str, "intervention": np.array([True], dtype=bool), + "next.success": np.array([False], dtype=bool), } dataset.add_frame(frame) record_tick += 1 @@ -471,6 +533,7 @@ class DAggerStrategy(RolloutStrategy): **action_frame, "task": task_str, "intervention": np.array([False], dtype=bool), + "next.success": np.array([False], dtype=bool), } dataset.add_frame(frame) record_tick += 1 @@ -481,6 +544,7 @@ class DAggerStrategy(RolloutStrategy): elapsed = time.perf_counter() - episode_start if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING: with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() episodes_since_push += 1 self._needs_push.set() @@ -510,6 +574,7 @@ class DAggerStrategy(RolloutStrategy): engine.pause() with contextlib.suppress(Exception): with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() self._needs_push.set() logger.info("Final in-progress episode saved") @@ -584,6 +649,7 @@ class DAggerStrategy(RolloutStrategy): # Correction ended -> save episode (blocking if not streaming) if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED: with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() recorded += 1 self._needs_push.set() @@ -625,6 +691,7 @@ class DAggerStrategy(RolloutStrategy): **action_frame, "task": task_str, "intervention": np.array([True], dtype=bool), + "next.success": np.array([False], dtype=bool), } ) record_tick += 1 @@ -659,6 +726,7 @@ class DAggerStrategy(RolloutStrategy): engine.pause() with contextlib.suppress(Exception): with self._episode_lock: + self._stamp_episode_success(dataset) dataset.save_episode() self._needs_push.set() logger.info("Final in-progress episode saved") diff --git a/tests/test_rollout.py b/tests/test_rollout.py index 85a29ff4c..cc1c48b87 100644 --- a/tests/test_rollout.py +++ b/tests/test_rollout.py @@ -338,6 +338,103 @@ def test_dagger_events_reset(): assert not events.upload_requested.is_set() +def test_dagger_mark_success(): + """mark_success sets the episode label to True.""" + from lerobot.rollout.strategies import DAggerEvents + + events = DAggerEvents() + assert events.consume_episode_success() is None + + events.mark_success() + assert events.consume_episode_success() is True + # Consuming clears the label + assert events.consume_episode_success() is None + + +def test_dagger_mark_failure(): + """mark_failure sets the episode label to False.""" + from lerobot.rollout.strategies import DAggerEvents + + events = DAggerEvents() + events.mark_failure() + assert events.consume_episode_success() is False + + +def test_dagger_success_overrides_failure(): + """Last label wins — success after failure overrides.""" + from lerobot.rollout.strategies import DAggerEvents + + events = DAggerEvents() + events.mark_failure() + events.mark_success() + assert events.consume_episode_success() is True + + +def test_dagger_reset_clears_success_label(): + """reset() clears any pending episode success label.""" + from lerobot.rollout.strategies import DAggerEvents + + events = DAggerEvents() + events.mark_success() + events.reset() + assert events.consume_episode_success() is None + + +def test_stamp_episode_success_labels_terminal_frame(): + """_stamp_episode_success sets last frame's next.success to True.""" + import numpy as np + + from lerobot.rollout.strategies.dagger import DAggerStrategy + + strategy = DAggerStrategy.__new__(DAggerStrategy) + strategy.config = MagicMock() + + from lerobot.rollout.strategies import DAggerEvents + + strategy._events = DAggerEvents() + strategy._events.mark_success() + + dataset = MagicMock() + dataset.writer.episode_buffer = { + "next.success": [ + np.array([False], dtype=bool), + np.array([False], dtype=bool), + np.array([False], dtype=bool), + ], + } + + strategy._stamp_episode_success(dataset) + + assert dataset.writer.episode_buffer["next.success"][-1].item() is True + assert dataset.writer.episode_buffer["next.success"][0].item() is False + + +def test_stamp_episode_success_no_label_stays_false(): + """Without a label, all frames remain False.""" + import numpy as np + + from lerobot.rollout.strategies.dagger import DAggerStrategy + + strategy = DAggerStrategy.__new__(DAggerStrategy) + strategy.config = MagicMock() + + from lerobot.rollout.strategies import DAggerEvents + + strategy._events = DAggerEvents() + + dataset = MagicMock() + dataset.writer.episode_buffer = { + "next.success": [ + np.array([False], dtype=bool), + np.array([False], dtype=bool), + ], + } + + strategy._stamp_episode_success(dataset) + + assert all(v.item() is False for v in dataset.writer.episode_buffer["next.success"]) + + # --------------------------------------------------------------------------- # Context dataclass # ---------------------------------------------------------------------------