fix(annotate): ground interjections in video + propagate text to plan refresh

qwen36moe-10 showed three Module-2 / plan-refresh quality issues that
are not architecture problems — they're prompt-grounding bugs (a
before/after snippet follows the list):

1. Interjection prompt passed ``current_subtask = record.episode_task``
   (the WHOLE-episode task), not the actual subtask in force at the
   chosen timestamp. The VLM had no signal about what was visible at
   that moment, so its interjections were generic ("actually skip X"
   where X had nothing to do with the visible activity).

2. Interjection prompt only attached a single frame
   (``frames_at(record, [t_snap])``). With one frozen image the VLM
   couldn't read the ongoing motion. Module 1 already gets the whole
   episode video for subtask decomposition, which is why subtasks are
   well-grounded; Module 2 was the outlier.

3. The plan-refresh prompt told the model "a plan refresh after a user
   interjection at t=X.YZs" but never showed it the interjection
   *text*. So the refreshed plan couldn't actually reflect the user's
   correction — at best it recombined the same step list.
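
For concreteness, the one-line before/after at the heart of issues 1
and 2, quoted from the Module 2 hunks below:

    # before: grounding fell back to the whole-episode task, one frame
    current_subtask = record.episode_task
    images = self.frame_provider.frames_at(record, [t_snap])

    # after: resolve the local subtask, attach a short multi-frame clip
    current_subtask = (
        self._subtask_at(subtask_spans, t_snap) or record.episode_task
    )
    images = self.frame_provider.frames_at(record, window_ts)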

Fix:

- ``interjections_and_speech.py``: Module 2 reads Module 1's subtask
  rows from the same staging tree (executor orders module_1 → module_2
  so they're already there) and resolves the actual ``current_subtask``
  at each chosen timestamp. Pulls a small clip
  (``interjection_window_seconds`` × ``interjection_window_frames``,
  defaulting to 4 frames over the leading 2 s) instead of one frame.
  Drops the silently-zeroing ``len(candidate_ts) // 4`` cap on the
  interjection count. (Both mechanisms are sketched after this list.)

- ``module_2_interjection.txt``: prompt is rewritten to reference the
  multi-frame visual context and require the interjection to mention
  something visible OR named in the current subtask, not invented.

- ``plan_subtasks_memory.py``: ``run_plan_updates`` now accepts and
  threads through interjection texts. ``_generate_plan(refresh_t,
  interjection)`` injects both the current subtask AND the interjection
  text into the prompt so the refreshed plan can drop / reorder /
  constrain steps to match the user's correction. (Plan still refreshes
  ONLY at user interjections — subtask generation runs ~1 Hz at
  inference, plan re-emission is event-driven. The text threading is
  sketched after this list.)

- ``executor.py``: forwards ``interjection_texts`` alongside
  ``interjection_times`` to ``run_plan_updates``.

- ``Module2Config``: bumps ``max_interjections_per_episode`` default
  from 1 to 3 and exposes ``interjection_window_seconds`` /
  ``interjection_window_frames``.
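
A toy sketch of the two grounding mechanisms from the first bullet.
The values and span dicts here are made up for illustration; the real
code operates on staging rows and an ``EpisodeRecord``:

    # Subtask resolution: the subtask "in force" at time t is the last
    # span whose start <= t (spans are sorted by start time).
    spans = [
        {"text": "pick up the cup", "start": 0.0, "end": 6.0},
        {"text": "place it on the shelf", "start": 6.0, "end": 12.0},
    ]
    current = None
    for span in spans:
        if span["start"] <= 7.3:
            current = span["text"]
        else:
            break
    assert current == "place it on the shelf"

    # Window sampling: interjection_window_frames=4 evenly spaced over
    # the interjection_window_seconds=2.0 leading up to the anchor,
    # i.e. offsets -2.0, -1.33, -0.67, 0.0 before snapping to frames.
    t_anchor, window, n = 7.3, 2.0, 4
    step = window / (n - 1)
    targets = [t_anchor - step * (n - 1 - i) for i in range(n)]
    # targets == [5.3, 5.966..., 6.633..., 7.3]; each is then snapped
    # to the nearest real frame timestamp and de-duplicated.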

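A second sketch: how the interjection text travels from Module 2's
staged rows into the plan-refresh prompt. This mirrors the executor
hunk below; ``times``/``texts`` are illustrative local names:

    rows = staging.read("module_2")  # written by the module_2 phase
    interjection_rows = [r for r in rows if r.get("style") == "interjection"]
    times = [float(r["timestamp"]) for r in interjection_rows]
    texts = [str(r.get("content") or "") for r in interjection_rows]
    # Paired by index: texts[i] is the user utterance spoken at times[i],
    # so _generate_plan can quote the utterance in the refresh prompt.
    self.module_1.run_plan_updates(record, staging, times, texts)
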
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author: Pepijn
Date:   2026-04-30 16:11:10 +02:00
Parent: b71e10da6b
Commit: 3434d2ef22
5 changed files with 177 additions and 27 deletions
@@ -62,8 +62,23 @@ class Module2Config:
     """Module 2 hyperparameters: interjections + paired speech."""
 
     enabled: bool = True
-    max_interjections_per_episode: int = 1
+    max_interjections_per_episode: int = 3
+    """Number of mid-episode interjections to generate per episode. Each
+    creates a paired ``(interjection, speech)`` event row plus triggers a
+    ``plan`` refresh at the same timestamp via Module 1. Bumped from the
+    original ``1`` after qwen36moe-10 showed plan/interjection coverage
+    was too sparse for Hi Robot-style training."""
     interjection_min_t: float = 2.0
+    interjection_window_seconds: float = 2.0
+    """How many seconds of video to attach to the interjection prompt as
+    visual context. Without this the VLM only sees a single frozen frame
+    and writes generic interjections that aren't grounded in the actual
+    motion happening at the chosen timestamp."""
+    interjection_window_frames: int = 4
+    """How many frames to sample over ``interjection_window_seconds``.
+    Default 4 ⇒ a frame every ~0.7 s over the leading 2 seconds — enough
+    for the model to read the ongoing motion, cheap enough to keep prompt
+    size bounded for the 32k context."""
 
 
 @dataclass
@@ -110,7 +110,9 @@ class Executor:
 
         # Phase 1: Module 1 (plan + subtasks + memory)
         phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
-        # Phase 2: Module 2 (interjections + speech)
+        # Phase 2: Module 2 (interjections + speech). Module 2 reads
+        # Module 1's subtask rows from the same staging tree to ground
+        # the interjection prompt in the correct local subtask.
         phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
         # Phase 3: Module 1 plan-update pass at interjection timestamps.
         phases.append(self._run_plan_update_phase(records, staging_dir))
@@ -198,10 +200,16 @@ class Executor:
         processed = 0
         for record in records:
             staging = EpisodeStaging(staging_dir, record.episode_index)
-            interjection_times = [
-                row["timestamp"] for row in staging.read("module_2") if row.get("style") == "interjection"
+            interjection_rows = [
+                row
+                for row in staging.read("module_2")
+                if row.get("style") == "interjection"
             ]
+            interjection_times = [float(row["timestamp"]) for row in interjection_rows]
+            interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
             if interjection_times:
-                self.module_1.run_plan_updates(record, staging, interjection_times)
+                self.module_1.run_plan_updates(
+                    record, staging, interjection_times, interjection_texts
+                )
             processed += 1
         return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0)
@@ -72,9 +72,37 @@ class InterjectionsAndSpeechModule:
         initial = self._initial_speech(record)
         if initial:
             rows.append(speech_atom(t0, initial))
-        rows.extend(self._mid_episode_interjections(record))
+        # Pull Module 1's subtask spans for this episode so the
+        # interjection prompt can ground itself in the actual current
+        # subtask at each chosen timestamp. Module 1 ran first.
+        subtask_spans = self._read_subtask_spans(staging)
+        rows.extend(self._mid_episode_interjections(record, subtask_spans))
         staging.write("module_2", rows)
+
+    @staticmethod
+    def _read_subtask_spans(staging: EpisodeStaging) -> list[dict[str, Any]]:
+        rows = [r for r in staging.read("module_1") if r.get("style") == "subtask"]
+        rows.sort(key=lambda r: float(r["timestamp"]))
+        spans: list[dict[str, Any]] = []
+        last_t: float | None = None
+        for r in rows:
+            t = float(r["timestamp"])
+            if last_t is not None and spans:
+                spans[-1]["end"] = t
+            spans.append({"text": r.get("content") or "", "start": t, "end": t})
+            last_t = t
+        return spans
+
+    @staticmethod
+    def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
+        current: str | None = None
+        for span in spans:
+            if float(span["start"]) <= t:
+                current = span.get("text")
+            else:
+                break
+        return current
 
     def _initial_speech(self, record: EpisodeRecord) -> str | None:
         prompt = load_prompt("module_2_initial_speech").format(
             episode_task=record.episode_task,
@@ -87,7 +115,11 @@ class InterjectionsAndSpeechModule:
             return text
         return None
 
-    def _mid_episode_interjections(self, record: EpisodeRecord) -> list[dict[str, Any]]:
+    def _mid_episode_interjections(
+        self,
+        record: EpisodeRecord,
+        subtask_spans: Sequence[dict[str, Any]],
+    ) -> list[dict[str, Any]]:
         if self.config.max_interjections_per_episode <= 0:
             return []
         # Deterministic per-episode RNG so reruns are stable across SLURM jobs.
@@ -95,20 +127,30 @@ class InterjectionsAndSpeechModule:
         candidate_ts = [t for t in record.frame_timestamps if t >= self.config.interjection_min_t]
         if not candidate_ts:
             return []
-        n = min(self.config.max_interjections_per_episode, len(candidate_ts) // 4)
+        # Pick at most ``max_interjections_per_episode`` distinct timestamps.
+        # Previously capped at ``len(candidate_ts) // 4`` — that cap was
+        # only relevant for very short episodes; for any real ~20-30s
+        # episode it had no effect, but it silently set the count to 0 on
+        # short fixtures. Just take ``min(max, len)`` directly.
+        n = min(self.config.max_interjections_per_episode, len(candidate_ts))
         if n <= 0:
             return []
         chosen = sorted(rng.sample(candidate_ts, n))
         out: list[dict[str, Any]] = []
         for t in chosen:
             t_snap = _snap_to_frame(t, record.frame_timestamps)
-            current_subtask = record.episode_task
+            window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
+            current_subtask = (
+                self._subtask_at(subtask_spans, t_snap) or record.episode_task
+            )
             prompt = load_prompt("module_2_interjection").format(
                 episode_task=record.episode_task,
                 current_subtask=current_subtask,
                 timestamp=t_snap,
+                window_seconds=self.config.interjection_window_seconds,
             )
-            images = self.frame_provider.frames_at(record, [t_snap])
+            images = self.frame_provider.frames_at(record, window_ts)
             content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
             messages = [{"role": "user", "content": content}]
             result = self.vlm.generate_json([messages])[0]
@@ -131,3 +173,31 @@ class InterjectionsAndSpeechModule:
             )
             out.append(speech_atom(t_snap, speech_text.strip()))
         return out
+
+    def _window_timestamps(
+        self, t_anchor: float, frame_timestamps: Sequence[float]
+    ) -> list[float]:
+        """Return frame timestamps spanning the lead-up to ``t_anchor``.
+
+        The VLM receives roughly ``interjection_window_frames`` frames
+        over the ``interjection_window_seconds`` immediately before
+        ``t_anchor``, snapped to actual source frame timestamps. This
+        gives the interjection prompt enough temporal context to read
+        what's visibly happening instead of looking at one frozen frame.
+        """
+        if not frame_timestamps:
+            return [t_anchor]
+        n = max(1, int(self.config.interjection_window_frames))
+        if n == 1:
+            return [t_anchor]
+        window = float(self.config.interjection_window_seconds)
+        step = window / max(1, n - 1)
+        targets = [t_anchor - step * (n - 1 - i) for i in range(n)]
+        snapped: list[float] = []
+        seen: set[float] = set()
+        for tgt in targets:
+            t = _snap_to_frame(max(0.0, tgt), frame_timestamps)
+            if t not in seen:
+                seen.add(t)
+                snapped.append(t)
+        return snapped or [t_anchor]
@@ -121,14 +121,37 @@ class PlanSubtasksMemoryModule:
         record: EpisodeRecord,
         staging: EpisodeStaging,
         interjection_times: Sequence[float],
+        interjection_texts: Sequence[str] | None = None,
     ) -> None:
-        """Append additional ``plan`` rows at every interjection timestamp."""
+        """Append additional ``plan`` rows at every interjection timestamp.
+
+        Plans refresh ONLY on user interjections — subtask generation
+        runs ~1 Hz at inference, but plan re-emission is event-driven.
+
+        Now also forwards the interjection's own text into the prompt so
+        the refreshed plan can actually reflect the user's correction
+        (the previous version told the model "an interjection happened"
+        without telling it what the user said).
+        """
         existing = staging.read("module_1")
         spans = self._reconstruct_subtasks_from_rows(existing)
+        already_planned: set[float] = {
+            float(r["timestamp"]) for r in existing if r.get("style") == "plan"
+        }
         new_rows = list(existing)
-        for raw_t in interjection_times:
+        texts: list[str | None] = (
+            [None] * len(interjection_times)
+            if interjection_texts is None
+            else [str(t) if t else None for t in interjection_texts]
+        )
+        for raw_t, inter_text in zip(interjection_times, texts):
             t = _snap_to_frame(raw_t, record.frame_timestamps)
-            plan_text = self._generate_plan(record, spans, refresh_t=t)
+            if t in already_planned:
+                continue
+            already_planned.add(t)
+            plan_text = self._generate_plan(
+                record, spans, refresh_t=t, interjection=inter_text
+            )
             if plan_text is not None:
                 new_rows.append(
                     {
@@ -215,6 +238,7 @@ class PlanSubtasksMemoryModule:
         subtask_spans: Sequence[dict[str, Any]],
         *,
         refresh_t: float | None = None,
+        interjection: str | None = None,
     ) -> str | None:
         if not subtask_spans:
             return None
@@ -225,7 +249,33 @@ class PlanSubtasksMemoryModule:
             plan_max_steps=self.config.plan_max_steps,
         )
         if refresh_t is not None:
-            prompt += f"\n\n(This is a plan refresh after a user interjection at t={refresh_t:.2f}s.)\n"
+            # ``current_subtask`` is the span the refresh time falls into,
+            # so the model knows where in the demonstration the planner is
+            # standing when it re-emits.
+            current_subtask = ""
+            for span in subtask_spans:
+                if float(span["start"]) <= refresh_t and (
+                    "end" not in span or float(span["end"]) > refresh_t
+                ):
+                    current_subtask = span.get("text", "")
+                    break
+            if interjection:
+                prompt += (
+                    f"\n\n(Plan refresh at t={refresh_t:.2f}s after a user "
+                    f"interjection: {interjection!r}. Current subtask just "
+                    f"before the interjection: {current_subtask!r}. Update "
+                    f"the plan so it reflects the interjection — drop or "
+                    f"reorder steps as needed; do not just restate.)\n"
+                )
+            else:
+                # Refresh without an interjection text: still tell the model
+                # where in the episode the plan stands so the re-emission
+                # is grounded. Should be rare — plan refreshes are
+                # interjection-driven by design.
+                prompt += (
+                    f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current "
+                    f"subtask: {current_subtask!r}.)\n"
+                )
         messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
         result = self.vlm.generate_json([messages])[0]
         if isinstance(result, dict) and isinstance(result.get("plan"), str):
@@ -1,27 +1,34 @@
 You are simulating a user mid-episode interruption for a robot doing:
 "{episode_task}".
 
-Synthesize ONE realistic interruption the user might say at this moment in
-the demonstration, plus the robot's verbal acknowledgement.
+The images above show roughly the last {window_seconds:.1f} seconds of the
+demonstration in chronological order. Read what the robot is actually
+doing right now and write an interruption that responds to that exact
+visible activity — not a generic one.
+
+Current subtask the robot is executing: {current_subtask}
+Time into episode: {timestamp:.2f}s
+
+Synthesize ONE realistic interruption the user might say at this moment,
+plus the robot's verbal acknowledgement.
 
 Context (Hi Robot, Shi 2025) — interjections fall into one of these
 scenario types:
-- negative task: "actually skip X"
-- situated correction: "that's not trash"
-- specific constraint: "use less salt"
-- preference: "could you also do Y"
+- negative task: "actually skip X" (where X is the visible current step)
+- situated correction: "that's not the right one, use the blue one"
+- specific constraint: "be more careful with that one"
+- preference: "could you also do Y after this"
 
 Interruption rules:
-- Must be plausible given the current subtask context.
+- Must reference an object, motion, or sub-step that is visible in the
+  attached frames OR explicitly named in the current subtask. Do not
+  invent objects that aren't there.
 - Must change the plan in a non-trivial way (a new constraint, skipped
   step, or correction).
-- One sentence each.
-
-Current subtask context: {current_subtask}
-Time into episode: {timestamp:.2f}s
+- One sentence each. Conversational, not robotic.
 
 Output strictly valid JSON:
 {{
-  "interjection": "<single sentence the user says>",
-  "speech": "<single sentence the robot speaks back>"
+  "interjection": "<single sentence the user says about what is visible right now>",
+  "speech": "<single sentence the robot speaks back, acknowledging the change>"
 }}