feat(rollout): add episode success labeling to DAgger strategy

This commit is contained in:
Khalil Meftah
2026-06-22 15:08:05 +02:00
parent 4b779b1e99
commit 0d2ba54385
4 changed files with 179 additions and 1 deletions
+8
View File
@@ -106,6 +106,8 @@ class DAggerKeyboardConfig:
pause_resume: str = "space"
correction: str = "tab"
upload: str = "enter"
success: str = "s"
failure: str = "f"
@dataclass
@@ -119,6 +121,8 @@ class DAggerPedalConfig:
pause_resume: str = "KEY_A"
correction: str = "KEY_B"
upload: str = "KEY_C"
success: str = "KEY_D"
failure: str = "KEY_E"
@RolloutStrategyConfig.register_subclass("episodic")
@@ -165,6 +169,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
2. **correction** — toggle human correction recording.
3. **upload** — push dataset to hub on demand (corrections-only mode).
Episode success labeling:
4. **success** — mark current episode as successful.
5. **failure** — mark current episode as failed.
When ``record_autonomous=False`` (default) only human-correction windows
are recorded — each correction becomes its own episode. Set to ``True``
to record both autonomous and correction frames with size-based episode
+5
View File
@@ -347,6 +347,11 @@ def build_rollout_context(
"shape": (1,),
"names": None,
}
dataset_features["next.success"] = {
"dtype": "bool",
"shape": (1,),
"names": None,
}
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
if not repo_name.startswith("rollout_"):
+69 -1
View File
@@ -129,6 +129,9 @@ class DAggerEvents:
self.stop_recording = Event()
self.upload_requested = Event()
# Episode success labeling
self._episode_success: bool | None = None
# -- Thread-safe phase access ------------------------------------------
@property
@@ -171,8 +174,26 @@ class DAggerEvents:
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self._episode_success = None
self.upload_requested.clear()
def mark_success(self) -> None:
    """Record a "success" label for the episode in progress.

    Invoked from the input-device callbacks; the flag is written while
    holding ``self._lock``.
    """
    with self._lock:
        self._episode_success = True
def mark_failure(self) -> None:
    """Record a "failure" label for the episode in progress.

    Invoked from the input-device callbacks; the flag is written while
    holding ``self._lock``.
    """
    with self._lock:
        self._episode_success = False
def consume_episode_success(self) -> bool | None:
    """Return the pending episode label and clear it in one locked step.

    Returns:
        ``True`` or ``False`` if a label was marked since the last
        consume/reset, otherwise ``None`` (unlabeled).
    """
    with self._lock:
        # Read-and-reset as a single swap so the label is handed out once.
        pending, self._episode_success = self._episode_success, None
    return pending
# ---------------------------------------------------------------------------
# Input device handlers
@@ -226,16 +247,25 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
events.request_transition(key_to_event[resolved])
if resolved == cfg.upload:
events.upload_requested.set()
if resolved == cfg.success:
events.mark_success()
logger.info("Episode marked as SUCCESS")
if resolved == cfg.failure:
events.mark_failure()
logger.info("Episode marked as FAILURE")
except Exception as e:
logger.debug("Key error: %s", e)
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.info(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
"DAgger keyboard listener started (pause_resume='%s', correction='%s', "
"upload='%s', success='%s', failure='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
cfg.upload,
cfg.success,
cfg.failure,
)
return listener
@@ -255,6 +285,12 @@ def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
events.request_transition(code_to_event[code])
if code == cfg.upload:
events.upload_requested.set()
if code == cfg.success:
events.mark_success()
logger.info("Episode marked as SUCCESS (pedal)")
if code == cfg.failure:
events.mark_failure()
logger.info("Episode marked as FAILURE (pedal)")
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
return start_pedal_listener(on_press, device_path=cfg.device_path)
@@ -357,6 +393,31 @@ class DAggerStrategy(RolloutStrategy):
)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
# Episode success labeling
# ------------------------------------------------------------------
def _stamp_episode_success(self, dataset) -> None:
    """Stamp the operator's success label onto the episode's terminal frame.

    Called just before ``save_episode()``. The pending label is always
    consumed, even when there is nothing to stamp, so a key press made
    while the episode buffer is missing or empty cannot leak onto the
    next episode. Only an explicit success label changes the data: the
    last frame's ``next.success`` is set to ``True``. A failure label or
    no label leaves every frame ``False`` (unlabeled = assumed failure).

    Args:
        dataset: Dataset whose ``writer.episode_buffer`` holds the frames
            of the episode about to be saved.
    """
    # Consume first: the label belongs to the episode that is ending now,
    # whether or not there are frames to write it to.
    label = self._events.consume_episode_success()

    buf = dataset.writer.episode_buffer
    if buf is None:
        return

    success_buf = buf.get("next.success")
    # Explicit None/length test instead of truthiness so an array-like
    # buffer does not raise on ambiguous truth value.
    if success_buf is None or len(success_buf) == 0:
        return

    if label:
        success_buf[-1] = np.array([True], dtype=bool)
        logger.info("Terminal frame stamped next.success=True")
# ------------------------------------------------------------------
# Continuous recording mode (record_autonomous=True)
# ------------------------------------------------------------------
@@ -443,6 +504,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
@@ -471,6 +533,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
@@ -481,6 +544,7 @@ class DAggerStrategy(RolloutStrategy):
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
@@ -510,6 +574,7 @@ class DAggerStrategy(RolloutStrategy):
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
@@ -584,6 +649,7 @@ class DAggerStrategy(RolloutStrategy):
# Correction ended -> save episode (blocking if not streaming)
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
recorded += 1
self._needs_push.set()
@@ -625,6 +691,7 @@ class DAggerStrategy(RolloutStrategy):
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
"next.success": np.array([False], dtype=bool),
}
)
record_tick += 1
@@ -659,6 +726,7 @@ class DAggerStrategy(RolloutStrategy):
engine.pause()
with contextlib.suppress(Exception):
with self._episode_lock:
self._stamp_episode_success(dataset)
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
+97
View File
@@ -338,6 +338,103 @@ def test_dagger_events_reset():
assert not events.upload_requested.is_set()
def test_dagger_mark_success():
    """A success mark is reported as True by one consume, then cleared."""
    from lerobot.rollout.strategies import DAggerEvents

    tracker = DAggerEvents()
    # Nothing has been marked yet.
    assert tracker.consume_episode_success() is None

    tracker.mark_success()
    first_read = tracker.consume_episode_success()
    assert first_read is True

    # Consuming clears the label
    second_read = tracker.consume_episode_success()
    assert second_read is None
def test_dagger_mark_failure():
    """A failure mark is reported as False by consume."""
    from lerobot.rollout.strategies import DAggerEvents

    tracker = DAggerEvents()
    tracker.mark_failure()

    label = tracker.consume_episode_success()
    assert label is False
def test_dagger_success_overrides_failure():
    """The most recent mark wins: failure followed by success reads True."""
    from lerobot.rollout.strategies import DAggerEvents

    tracker = DAggerEvents()
    # Apply the marks in order; only the last one should survive.
    tracker.mark_failure()
    tracker.mark_success()

    label = tracker.consume_episode_success()
    assert label is True
def test_dagger_reset_clears_success_label():
    """A label marked before reset() is gone afterwards."""
    from lerobot.rollout.strategies import DAggerEvents

    tracker = DAggerEvents()
    tracker.mark_success()
    tracker.reset()

    label_after_reset = tracker.consume_episode_success()
    assert label_after_reset is None
def test_stamp_episode_success_labels_terminal_frame():
    """With a pending success label, only the final frame becomes True."""
    import numpy as np

    from lerobot.rollout.strategies.dagger import DAggerStrategy
    from lerobot.rollout.strategies import DAggerEvents

    # Build the strategy without running __init__, then attach what the
    # test needs by hand.
    strategy = DAggerStrategy.__new__(DAggerStrategy)
    strategy.config = MagicMock()
    strategy._events = DAggerEvents()
    strategy._events.mark_success()

    # Three buffered frames, all initially unlabeled (False).
    frames = [np.array([False], dtype=bool) for _ in range(3)]
    dataset = MagicMock()
    dataset.writer.episode_buffer = {"next.success": frames}

    strategy._stamp_episode_success(dataset)

    stamped = dataset.writer.episode_buffer["next.success"]
    assert stamped[-1].item() is True
    assert stamped[0].item() is False
def test_stamp_episode_success_no_label_stays_false():
    """With no pending label, stamping leaves every frame False."""
    import numpy as np

    from lerobot.rollout.strategies.dagger import DAggerStrategy
    from lerobot.rollout.strategies import DAggerEvents

    # Build the strategy without running __init__; no label is marked.
    strategy = DAggerStrategy.__new__(DAggerStrategy)
    strategy.config = MagicMock()
    strategy._events = DAggerEvents()

    frames = [np.array([False], dtype=bool) for _ in range(2)]
    dataset = MagicMock()
    dataset.writer.episode_buffer = {"next.success": frames}

    strategy._stamp_episode_success(dataset)

    for flag in dataset.writer.episode_buffer["next.success"]:
        assert flag.item() is False
# ---------------------------------------------------------------------------
# Context dataclass
# ---------------------------------------------------------------------------