mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
feat(rollout): add episode success/failure labeling to DAgger strategy
Enable operators to mark episodes as success or failure during DAgger data collection. Pressing 's' or 'f' immediately saves the episode with the appropriate label and returns the robot to its initial position. - Add success/failure key bindings to DAggerKeyboardConfig - Add save_episode_requested event and episode_success state to DAggerEvents - Stamp next.success=True on terminal frame for successful episodes - Pause and return to initial position after manual save for env reset - Add num_episodes target to stop continuous recording automatically - Defer save during corrections to avoid splitting mid-intervention
This commit is contained in:
@@ -106,6 +106,8 @@ class DAggerKeyboardConfig:
|
||||
pause_resume: str = "space"
|
||||
correction: str = "tab"
|
||||
upload: str = "enter"
|
||||
success: str = "s"
|
||||
failure: str = "f"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -165,6 +167,10 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
2. **correction** — toggle human correction recording.
|
||||
3. **upload** — push dataset to hub on demand (corrections-only mode).
|
||||
|
||||
Episode success labeling:
|
||||
4. **success** — mark current episode as successful.
|
||||
5. **failure** — mark current episode as failed.
|
||||
|
||||
When ``record_autonomous=False`` (default) only human-correction windows
|
||||
are recorded — each correction becomes its own episode. Set to ``True``
|
||||
to record both autonomous and correction frames with size-based episode
|
||||
|
||||
@@ -350,6 +350,11 @@ def build_rollout_context(
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
dataset_features["next.success"] = {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
|
||||
if not repo_name.startswith("rollout_"):
|
||||
|
||||
@@ -112,6 +112,11 @@ class DAggerEvents:
|
||||
# Session-level flags
|
||||
self.stop_recording = Event()
|
||||
self.upload_requested = Event()
|
||||
# Set when operator presses success/failure key to end the current episode.
|
||||
self.save_episode_requested = Event()
|
||||
|
||||
# Episode success labeling
|
||||
self._episode_success: bool | None = None
|
||||
|
||||
# -- Thread-safe phase access ------------------------------------------
|
||||
|
||||
@@ -155,7 +160,26 @@ class DAggerEvents:
|
||||
with self._lock:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition = None
|
||||
self._episode_success = None
|
||||
self.upload_requested.clear()
|
||||
self.save_episode_requested.clear()
|
||||
|
||||
def mark_success(self) -> None:
|
||||
"""Mark the current episode as successful (called from input threads)."""
|
||||
with self._lock:
|
||||
self._episode_success = True
|
||||
|
||||
def mark_failure(self) -> None:
|
||||
"""Mark the current episode as failed (called from input threads)."""
|
||||
with self._lock:
|
||||
self._episode_success = False
|
||||
|
||||
def consume_episode_success(self) -> bool | None:
|
||||
"""Consume and reset the episode success label. Returns None if unlabeled."""
|
||||
with self._lock:
|
||||
result = self._episode_success
|
||||
self._episode_success = None
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -186,12 +210,20 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
events.request_transition(key_to_event[name])
|
||||
if name == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
if name == cfg.success:
|
||||
events.mark_success()
|
||||
events.save_episode_requested.set()
|
||||
logger.info("Episode marked as SUCCESS — saving")
|
||||
if name == cfg.failure:
|
||||
events.mark_failure()
|
||||
events.save_episode_requested.set()
|
||||
logger.info("Episode marked as FAILURE — saving")
|
||||
|
||||
return create_key_listener(
|
||||
dispatch,
|
||||
controls_help=(
|
||||
f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
|
||||
f"upload='{cfg.upload}', ESC=stop"
|
||||
f"upload='{cfg.upload}', success='{cfg.success}', failure='{cfg.failure}', ESC=stop"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -313,6 +345,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
)
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode success labeling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _stamp_episode_success(self, dataset) -> None:
|
||||
"""Set next.success on the terminal frame based on operator label.
|
||||
|
||||
Called just before save_episode(). If the operator pressed the success
|
||||
key during this episode, the last frame's next.success is set to True.
|
||||
Otherwise all frames remain False (unlabeled = assumed failure).
|
||||
"""
|
||||
buf = dataset.writer.episode_buffer
|
||||
if buf is None:
|
||||
return
|
||||
|
||||
success_buf = buf.get("next.success")
|
||||
if not success_buf:
|
||||
return
|
||||
|
||||
label = self._events.consume_episode_success()
|
||||
|
||||
if label:
|
||||
success_buf[-1] = np.array([True], dtype=bool)
|
||||
logger.info("Terminal frame stamped next.success=True")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Continuous recording mode (record_autonomous=True)
|
||||
# ------------------------------------------------------------------
|
||||
@@ -350,7 +407,12 @@ class DAggerStrategy(RolloutStrategy):
|
||||
episode_start = time.perf_counter()
|
||||
episodes_since_push = 0
|
||||
episode_duration_s = self._episode_duration_s
|
||||
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
|
||||
num_episodes = self.config.num_episodes
|
||||
logger.info(
|
||||
"DAgger continuous recording started (episode_duration=%.0fs, target=%s eps)",
|
||||
episode_duration_s,
|
||||
num_episodes if num_episodes is not None else "∞",
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
@@ -399,6 +461,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
@@ -427,23 +490,32 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([False], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
|
||||
# Episode rotation derived from the video file-size target.
|
||||
# Saving is deferred while a correction is ongoing so the
|
||||
# episode boundary lands on a clean autonomous frame.
|
||||
# Episode rotation: either the operator pressed success/failure,
|
||||
# or the video file-size target was reached.
|
||||
# Defer the save while a correction is ongoing so the episode
|
||||
# boundary lands on a clean autonomous frame. The event stays
|
||||
# set until we actually save, so it won't be lost.
|
||||
manual_save = events.save_episode_requested.is_set()
|
||||
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
|
||||
if (manual_save or elapsed >= episode_duration_s) and phase != DAggerPhase.CORRECTING:
|
||||
if manual_save:
|
||||
events.save_episode_requested.clear()
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push.set()
|
||||
save_reason = "manual save" if manual_save else f"elapsed {elapsed:.1f}s"
|
||||
logger.info(
|
||||
"Episode saved (total: %d, elapsed: %.1fs)",
|
||||
"Episode saved (%s, total: %d)",
|
||||
save_reason,
|
||||
dataset.num_episodes,
|
||||
elapsed,
|
||||
)
|
||||
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
|
||||
|
||||
@@ -451,6 +523,24 @@ class DAggerStrategy(RolloutStrategy):
|
||||
self._background_push(dataset, cfg)
|
||||
episodes_since_push = 0
|
||||
|
||||
if num_episodes is not None and dataset.num_episodes >= num_episodes:
|
||||
logger.info("Target episode count reached (%d), stopping session", num_episodes)
|
||||
log_say(f"All {num_episodes} episodes collected", play_sounds)
|
||||
events.stop_recording.set()
|
||||
break
|
||||
|
||||
# Pause after manual save: stop the policy, return robot to
|
||||
# initial position, and wait for the operator to reset the
|
||||
# environment and press SPACE.
|
||||
if manual_save:
|
||||
engine.pause()
|
||||
events.phase = DAggerPhase.PAUSED
|
||||
self._return_to_initial_position(ctx.hardware)
|
||||
logger.info(
|
||||
"Episode saved — paused for environment reset. Press SPACE to start next episode."
|
||||
)
|
||||
log_say("Reset the environment, then press space", play_sounds)
|
||||
|
||||
episode_start = time.perf_counter()
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
@@ -466,6 +556,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.pause()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
@@ -540,6 +631,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
# Correction ended -> save episode (blocking if not streaming)
|
||||
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
self._needs_push.set()
|
||||
@@ -581,6 +673,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
"next.success": np.array([False], dtype=bool),
|
||||
}
|
||||
)
|
||||
record_tick += 1
|
||||
@@ -615,6 +708,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.pause()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
self._stamp_episode_success(dataset)
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
|
||||
Reference in New Issue
Block a user