mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 19:57:27 +00:00
feat(rollout): add episode success labeling to DAgger strategy
This commit is contained in:
@@ -106,6 +106,8 @@ class DAggerKeyboardConfig:
|
||||
pause_resume: str = "space"
|
||||
correction: str = "tab"
|
||||
upload: str = "enter"
|
||||
success: str = "s"
|
||||
failure: str = "f"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -119,6 +121,8 @@ class DAggerPedalConfig:
|
||||
pause_resume: str = "KEY_A"
|
||||
correction: str = "KEY_B"
|
||||
upload: str = "KEY_C"
|
||||
success: str = "KEY_D"
|
||||
failure: str = "KEY_E"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@@ -165,6 +169,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
2. **correction** — toggle human correction recording.
|
||||
3. **upload** — push dataset to hub on demand (corrections-only mode).
|
||||
|
||||
Episode success labeling:
|
||||
4. **success** — mark current episode as successful.
|
||||
5. **failure** — mark current episode as failed.
|
||||
|
||||
When ``record_autonomous=False`` (default) only human-correction windows
|
||||
are recorded — each correction becomes its own episode. Set to ``True``
|
||||
to record both autonomous and correction frames with size-based episode
|
||||
|
||||
@@ -347,6 +347,11 @@ def build_rollout_context(
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
dataset_features["next.success"] = {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
|
||||
if not repo_name.startswith("rollout_"):
|
||||
|
||||
@@ -129,6 +129,9 @@ class DAggerEvents:
|
||||
self.stop_recording = Event()
|
||||
self.upload_requested = Event()
|
||||
|
||||
# Episode success labeling
|
||||
self._episode_success: bool | None = None
|
||||
|
||||
# -- Thread-safe phase access ------------------------------------------
|
||||
|
||||
@property
|
||||
@@ -171,8 +174,26 @@ class DAggerEvents:
|
||||
with self._lock:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition = None
|
||||
self._episode_success = None
|
||||
self.upload_requested.clear()
|
||||
|
||||
def mark_success(self) -> None:
    """Record a "success" label for the episode currently in progress.

    Intended to be invoked by the input-device callbacks; the write is
    performed while holding ``self._lock``.
    """
    with self._lock:
        self._episode_success = True
|
||||
|
||||
def mark_failure(self) -> None:
    """Record a "failure" label for the episode currently in progress.

    Intended to be invoked by the input-device callbacks; the write is
    performed while holding ``self._lock``.
    """
    with self._lock:
        self._episode_success = False
|
||||
|
||||
def consume_episode_success(self) -> bool | None:
|
||||
"""Consume and reset the episode success label. Returns None if unlabeled."""
|
||||
with self._lock:
|
||||
result = self._episode_success
|
||||
self._episode_success = None
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
@@ -226,16 +247,25 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
events.request_transition(key_to_event[resolved])
|
||||
if resolved == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
if resolved == cfg.success:
|
||||
events.mark_success()
|
||||
logger.info("Episode marked as SUCCESS")
|
||||
if resolved == cfg.failure:
|
||||
events.mark_failure()
|
||||
logger.info("Episode marked as FAILURE")
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
logger.info(
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', "
|
||||
"upload='%s', success='%s', failure='%s', ESC=stop)",
|
||||
cfg.pause_resume,
|
||||
cfg.correction,
|
||||
cfg.upload,
|
||||
cfg.success,
|
||||
cfg.failure,
|
||||
)
|
||||
return listener
|
||||
|
||||
@@ -255,6 +285,12 @@ def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
|
||||
events.request_transition(code_to_event[code])
|
||||
if code == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
if code == cfg.success:
|
||||
events.mark_success()
|
||||
logger.info("Episode marked as SUCCESS (pedal)")
|
||||
if code == cfg.failure:
|
||||
events.mark_failure()
|
||||
logger.info("Episode marked as FAILURE (pedal)")
|
||||
|
||||
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
|
||||
return start_pedal_listener(on_press, device_path=cfg.device_path)
|
||||
@@ -357,6 +393,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
)
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode success labeling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _stamp_episode_success(self, dataset) -> None:
    """Set ``next.success`` on the terminal frame based on the operator label.

    Called just before ``save_episode()``. If the operator pressed the
    success key during this episode, the last buffered frame's
    ``next.success`` is replaced with ``True``. A failure label or no label
    leaves every frame untouched (unlabeled = assumed failure).

    The pending label is consumed on every call, including when there is
    no episode buffer or no buffered ``next.success`` frames.

    Args:
        dataset: Object exposing ``writer.episode_buffer``, which is either
            ``None`` or a mapping with a ``"next.success"`` sequence.
    """
    # Consume first: otherwise a label pressed while nothing was buffered
    # would survive the early returns below and be stamped on the NEXT episode.
    label = self._events.consume_episode_success()

    buf = dataset.writer.episode_buffer
    if buf is None:
        return

    success_buf = buf.get("next.success")
    # Explicit None/length check instead of truthiness, which raises for a
    # multi-element ndarray.
    if success_buf is None or len(success_buf) == 0:
        return

    if label:
        success_buf[-1] = np.array([True], dtype=bool)
        logger.info("Terminal frame stamped next.success=True")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Continuous recording mode (record_autonomous=True)
|
||||
# ------------------------------------------------------------------
|
||||
@@ -443,6 +504,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
@@ -471,6 +533,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([False], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
@@ -481,6 +544,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push.set()
|
||||
@@ -510,6 +574,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.pause()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
@@ -584,6 +649,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
# Correction ended -> save episode (blocking if not streaming)
|
||||
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
self._needs_push.set()
|
||||
@@ -625,6 +691,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
)
|
||||
record_tick += 1
|
||||
@@ -659,6 +726,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.pause()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
|
||||
@@ -338,6 +338,103 @@ def test_dagger_events_reset():
|
||||
assert not events.upload_requested.is_set()
|
||||
|
||||
|
||||
def test_dagger_mark_success():
    """A success mark is handed out exactly once by consume."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()

    # Fresh events carry no label.
    assert dagger_events.consume_episode_success() is None

    dagger_events.mark_success()
    first_read = dagger_events.consume_episode_success()
    assert first_read is True

    # The first consume cleared the label, so a second read sees nothing.
    second_read = dagger_events.consume_episode_success()
    assert second_read is None
|
||||
|
||||
|
||||
def test_dagger_mark_failure():
    """A failure mark is reported by consume as ``False`` (not ``None``)."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()
    dagger_events.mark_failure()

    label = dagger_events.consume_episode_success()
    assert label is False
|
||||
|
||||
|
||||
def test_dagger_success_overrides_failure():
    """The most recent mark wins: success pressed after failure yields True."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()

    # Apply the marks in order; only the last one should be kept.
    dagger_events.mark_failure()
    dagger_events.mark_success()

    label = dagger_events.consume_episode_success()
    assert label is True
|
||||
|
||||
|
||||
def test_dagger_reset_clears_success_label():
    """A pending label does not survive ``reset()``."""
    from lerobot.rollout.strategies import DAggerEvents

    dagger_events = DAggerEvents()
    dagger_events.mark_success()

    dagger_events.reset()

    label = dagger_events.consume_episode_success()
    assert label is None
|
||||
|
||||
|
||||
def test_stamp_episode_success_labels_terminal_frame():
    """_stamp_episode_success sets only the last frame's next.success to True.

    Checks every non-terminal frame (not just the first), and checks that
    stamping consumes the pending label.
    """
    import numpy as np

    from lerobot.rollout.strategies.dagger import DAggerStrategy

    # Bypass __init__: only the attributes assigned below are set up here.
    strategy = DAggerStrategy.__new__(DAggerStrategy)
    strategy.config = MagicMock()

    from lerobot.rollout.strategies import DAggerEvents

    strategy._events = DAggerEvents()
    strategy._events.mark_success()

    dataset = MagicMock()
    dataset.writer.episode_buffer = {
        "next.success": [
            np.array([False], dtype=bool),
            np.array([False], dtype=bool),
            np.array([False], dtype=bool),
        ],
    }

    strategy._stamp_episode_success(dataset)

    stamped = dataset.writer.episode_buffer["next.success"]
    assert stamped[-1].item() is True
    # Every frame before the terminal one must be left untouched.
    for frame in stamped[:-1]:
        assert frame.item() is False
    # The label was consumed by the stamp, so nothing is pending afterwards.
    assert strategy._events.consume_episode_success() is None
|
||||
|
||||
|
||||
def test_stamp_episode_success_no_label_stays_false():
    """With no operator label pending, stamping leaves every frame False."""
    import numpy as np

    from lerobot.rollout.strategies.dagger import DAggerStrategy

    # Bypass __init__: only the attributes assigned below are set up here.
    strategy = DAggerStrategy.__new__(DAggerStrategy)
    strategy.config = MagicMock()

    from lerobot.rollout.strategies import DAggerEvents

    strategy._events = DAggerEvents()

    dataset = MagicMock()
    dataset.writer.episode_buffer = {
        "next.success": [
            np.array([False], dtype=bool),
            np.array([False], dtype=bool),
        ],
    }

    strategy._stamp_episode_success(dataset)

    for frame in dataset.writer.episode_buffer["next.success"]:
        assert frame.item() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user