diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py index ff9a11f4e..50372d8f2 100644 --- a/src/lerobot/annotations/steerable_pipeline/config.py +++ b/src/lerobot/annotations/steerable_pipeline/config.py @@ -62,8 +62,23 @@ class Module2Config: """Module 2 hyperparameters: interjections + paired speech.""" enabled: bool = True - max_interjections_per_episode: int = 1 + max_interjections_per_episode: int = 3 + """Number of mid-episode interjections to generate per episode. Each + creates a paired ``(interjection, speech)`` event row plus triggers a + ``plan`` refresh at the same timestamp via Module 1. Bumped from the + original ``1`` after qwen36moe-10 showed plan/interjection coverage + was too sparse for Hi Robot-style training.""" interjection_min_t: float = 2.0 + interjection_window_seconds: float = 2.0 + """How many seconds of video to attach to the interjection prompt as + visual context. Without this the VLM only sees a single frozen frame + and writes generic interjections that aren't grounded in the actual + motion happening at the chosen timestamp.""" + interjection_window_frames: int = 4 + """How many frames to sample over ``interjection_window_seconds``. + Default 4 ⇒ ~0.5 fps over the leading 2 seconds — enough for the + model to read the ongoing motion, cheap enough to keep prompt size + bounded for the 32k context.""" @dataclass diff --git a/src/lerobot/annotations/steerable_pipeline/executor.py b/src/lerobot/annotations/steerable_pipeline/executor.py index a6d73a32c..b24d698d6 100644 --- a/src/lerobot/annotations/steerable_pipeline/executor.py +++ b/src/lerobot/annotations/steerable_pipeline/executor.py @@ -110,7 +110,9 @@ class Executor: # Phase 1: Module 1 (plan + subtasks + memory) phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1)) - # Phase 2: Module 2 (interjections + speech) + # Phase 2: Module 2 (interjections + speech). Module 2 reads + # Module 1's subtask rows from the same staging tree to ground + # the interjection prompt in the correct local subtask. phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2)) # Phase 3: Module 1 plan-update pass at interjection timestamps. phases.append(self._run_plan_update_phase(records, staging_dir)) @@ -198,10 +200,16 @@ class Executor: processed = 0 for record in records: staging = EpisodeStaging(staging_dir, record.episode_index) - interjection_times = [ - row["timestamp"] for row in staging.read("module_2") if row.get("style") == "interjection" + interjection_rows = [ + row + for row in staging.read("module_2") + if row.get("style") == "interjection" ] + interjection_times = [float(row["timestamp"]) for row in interjection_rows] + interjection_texts = [str(row.get("content") or "") for row in interjection_rows] if interjection_times: - self.module_1.run_plan_updates(record, staging, interjection_times) + self.module_1.run_plan_updates( + record, staging, interjection_times, interjection_texts + ) processed += 1 return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0) diff --git a/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py index d9b19959a..b65b08b6a 100644 --- a/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py +++ b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py @@ -72,9 +72,37 @@ class InterjectionsAndSpeechModule: initial = self._initial_speech(record) if initial: rows.append(speech_atom(t0, initial)) - rows.extend(self._mid_episode_interjections(record)) + # Pull Module 1's subtask spans for this episode so the + # interjection prompt can ground itself in the actual current + # subtask at each chosen timestamp. Module 1 ran first. + subtask_spans = self._read_subtask_spans(staging) + rows.extend(self._mid_episode_interjections(record, subtask_spans)) staging.write("module_2", rows) + @staticmethod + def _read_subtask_spans(staging: EpisodeStaging) -> list[dict[str, Any]]: + rows = [r for r in staging.read("module_1") if r.get("style") == "subtask"] + rows.sort(key=lambda r: float(r["timestamp"])) + spans: list[dict[str, Any]] = [] + last_t: float | None = None + for r in rows: + t = float(r["timestamp"]) + if last_t is not None and spans: + spans[-1]["end"] = t + spans.append({"text": r.get("content") or "", "start": t, "end": t}) + last_t = t + return spans + + @staticmethod + def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None: + current: str | None = None + for span in spans: + if float(span["start"]) <= t: + current = span.get("text") + else: + break + return current + def _initial_speech(self, record: EpisodeRecord) -> str | None: prompt = load_prompt("module_2_initial_speech").format( episode_task=record.episode_task, @@ -87,7 +115,11 @@ class InterjectionsAndSpeechModule: return text return None - def _mid_episode_interjections(self, record: EpisodeRecord) -> list[dict[str, Any]]: + def _mid_episode_interjections( + self, + record: EpisodeRecord, + subtask_spans: Sequence[dict[str, Any]], + ) -> list[dict[str, Any]]: if self.config.max_interjections_per_episode <= 0: return [] # Deterministic per-episode RNG so reruns are stable across SLURM jobs. @@ -95,20 +127,30 @@ class InterjectionsAndSpeechModule: candidate_ts = [t for t in record.frame_timestamps if t >= self.config.interjection_min_t] if not candidate_ts: return [] - n = min(self.config.max_interjections_per_episode, len(candidate_ts) // 4) + # Pick at most ``max_interjections_per_episode`` distinct timestamps. + # Previously capped at ``len(candidate_ts) // 4`` — that floor was + # only relevant for very short episodes; for any real ~20-30s + # episode it had no effect, but it silently set the count to 0 on + # short fixtures. Just take ``min(max, len)`` directly. + n = min(self.config.max_interjections_per_episode, len(candidate_ts)) if n <= 0: return [] chosen = sorted(rng.sample(candidate_ts, n)) + out: list[dict[str, Any]] = [] for t in chosen: t_snap = _snap_to_frame(t, record.frame_timestamps) - current_subtask = record.episode_task + window_ts = self._window_timestamps(t_snap, record.frame_timestamps) + current_subtask = ( + self._subtask_at(subtask_spans, t_snap) or record.episode_task + ) prompt = load_prompt("module_2_interjection").format( episode_task=record.episode_task, current_subtask=current_subtask, timestamp=t_snap, + window_seconds=self.config.interjection_window_seconds, ) - images = self.frame_provider.frames_at(record, [t_snap]) + images = self.frame_provider.frames_at(record, window_ts) content = [*to_image_blocks(images), {"type": "text", "text": prompt}] messages = [{"role": "user", "content": content}] result = self.vlm.generate_json([messages])[0] @@ -131,3 +173,31 @@ class InterjectionsAndSpeechModule: ) out.append(speech_atom(t_snap, speech_text.strip())) return out + + def _window_timestamps( + self, t_anchor: float, frame_timestamps: Sequence[float] + ) -> list[float]: + """Return a small set of frame timestamps spanning the lead-up to ``t``. + + The VLM receives roughly ``num_frames`` frames over the + ``window_seconds`` immediately before ``t_anchor``, snapped to + actual source frame timestamps. This gives the interjection + prompt enough temporal context to read what's visibly happening + instead of looking at one frozen frame. + """ + if not frame_timestamps: + return [t_anchor] + n = max(1, int(self.config.interjection_window_frames)) + if n == 1: + return [t_anchor] + window = float(self.config.interjection_window_seconds) + step = window / max(1, n - 1) + targets = [t_anchor - step * (n - 1 - i) for i in range(n)] + snapped: list[float] = [] + seen: set[float] = set() + for tgt in targets: + t = _snap_to_frame(max(0.0, tgt), frame_timestamps) + if t not in seen: + seen.add(t) + snapped.append(t) + return snapped or [t_anchor] diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py index 6c74b3134..c125e3640 100644 --- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -121,14 +121,37 @@ class PlanSubtasksMemoryModule: record: EpisodeRecord, staging: EpisodeStaging, interjection_times: Sequence[float], + interjection_texts: Sequence[str] | None = None, ) -> None: - """Append additional ``plan`` rows at every interjection timestamp.""" + """Append additional ``plan`` rows at every interjection timestamp. + + Plans refresh ONLY on user interjections — subtask generation + runs ~1 Hz at inference, but plan re-emission is event-driven. + Now also forwards the interjection's own text into the prompt so + the refreshed plan can actually reflect the user's correction + (the previous version told the model "an interjection happened" + without telling it what the user said). + """ existing = staging.read("module_1") spans = self._reconstruct_subtasks_from_rows(existing) + already_planned: set[float] = { + float(r["timestamp"]) for r in existing if r.get("style") == "plan" + } new_rows = list(existing) - for raw_t in interjection_times: + + texts: list[str | None] = ( + [None] * len(interjection_times) + if interjection_texts is None + else [str(t) if t else None for t in interjection_texts] + ) + for raw_t, inter_text in zip(interjection_times, texts): t = _snap_to_frame(raw_t, record.frame_timestamps) - plan_text = self._generate_plan(record, spans, refresh_t=t) + if t in already_planned: + continue + already_planned.add(t) + plan_text = self._generate_plan( + record, spans, refresh_t=t, interjection=inter_text + ) if plan_text is not None: new_rows.append( { @@ -215,6 +238,7 @@ class PlanSubtasksMemoryModule: subtask_spans: Sequence[dict[str, Any]], *, refresh_t: float | None = None, + interjection: str | None = None, ) -> str | None: if not subtask_spans: return None @@ -225,7 +249,33 @@ class PlanSubtasksMemoryModule: plan_max_steps=self.config.plan_max_steps, ) if refresh_t is not None: - prompt += f"\n\n(This is a plan refresh after a user interjection at t={refresh_t:.2f}s.)\n" + # ``current_subtask`` is the span the refresh time falls into, + # so the model knows where in the demonstration the planner is + # standing when it re-emits. + current_subtask = "" + for span in subtask_spans: + if float(span["start"]) <= refresh_t and ( + "end" not in span or float(span["end"]) > refresh_t + ): + current_subtask = span.get("text", "") + break + if interjection: + prompt += ( + f"\n\n(Plan refresh at t={refresh_t:.2f}s after a user " + f"interjection: {interjection!r}. Current subtask just " + f"before the interjection: {current_subtask!r}. Update " + f"the plan so it reflects the interjection — drop or " + f"reorder steps as needed; do not just restate.)\n" + ) + else: + # Refresh without an interjection text: still tell the model + # where in the episode the plan stands so the re-emission + # is grounded. Should be rare — plan refreshes are + # interjection-driven by design. + prompt += ( + f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current " + f"subtask: {current_subtask!r}.)\n" + ) messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] result = self.vlm.generate_json([messages])[0] if isinstance(result, dict) and isinstance(result.get("plan"), str): diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/module_2_interjection.txt b/src/lerobot/annotations/steerable_pipeline/prompts/module_2_interjection.txt index 0ecb78f9d..600251516 100644 --- a/src/lerobot/annotations/steerable_pipeline/prompts/module_2_interjection.txt +++ b/src/lerobot/annotations/steerable_pipeline/prompts/module_2_interjection.txt @@ -1,27 +1,34 @@ You are simulating a user mid-episode interruption for a robot doing: "{episode_task}". -Synthesize ONE realistic interruption the user might say at this moment in -the demonstration, plus the robot's verbal acknowledgement. +The images above show roughly the last {window_seconds:.1f} seconds of the +demonstration in chronological order. Read what the robot is actually +doing right now and write an interruption that responds to that exact +visible activity — not a generic one. + +Current subtask the robot is executing: {current_subtask} +Time into episode: {timestamp:.2f}s + +Synthesize ONE realistic interruption the user might say at this moment, +plus the robot's verbal acknowledgement. Context (Hi Robot, Shi 2025) — interjections fall into one of these scenario types: -- negative task: "actually skip X" -- situated correction: "that's not trash" -- specific constraint: "use less salt" -- preference: "could you also do Y" +- negative task: "actually skip X" (where X is the visible current step) +- situated correction: "that's not the right one, use the blue one" +- specific constraint: "be more careful with that one" +- preference: "could you also do Y after this" Interruption rules: -- Must be plausible given the current subtask context. +- Must reference an object, motion, or sub-step that is visible in the + attached frames OR explicitly named in the current subtask. Do not + invent objects that aren't there. - Must change the plan in a non-trivial way (a new constraint, skipped step, or correction). -- One sentence each. - -Current subtask context: {current_subtask} -Time into episode: {timestamp:.2f}s +- One sentence each. Conversational, not robotic. Output strictly valid JSON: {{ - "interjection": "", - "speech": "" + "interjection": "", + "speech": "" }}