refactor(annotate): consolidate Module 1's prompt → VLM → JSON-extract pattern

Five Module 1 sub-prompts (`_derive_task_from_video`,
`_generate_task_rephrasings`, `_generate_subtasks`, `_generate_plan`,
`_generate_memory`) all repeated the same shape:

    result = self.vlm.generate_json([messages])[0]
    if isinstance(result, dict) and isinstance(result.get(<field>), <type>):
        ...

…each spelled with slightly different field names + post-processing.

Three small helpers replace it:

* `_vlm_field(messages, field)` — single VLM call, returns
  ``result[field]`` or ``None``. Centralizes the
  ``generate_json([m])[0]`` + ``isinstance(dict)`` dance.
* `_text_message(text)` — wraps a string in the canonical user-message
  shape every text-only prompt builds inline.
* `_video_message(record, prompt)` — combines the episode video block
  with a prompt; replaces the duplicated video-block construction
  inside `_generate_subtasks` (which previously inlined the same
  ``use_video_url``/``frames_per_second``/``max_video_frames`` branches
  that `_episode_video_block` already implements).

Net -35 LOC. Each call site is now 3-5 lines instead of 10-20. The
public method signatures are unchanged, so tests don't move.

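Drive-by: `_task_seems_bad` collapsed via SIM103 fix; `zip` in
`run_plan_updates` now passes `strict=True` per ruff B905. Note that this
is a behavior change, not only a lint fix: a length mismatch between
`interjection_times` and `interjection_texts` now raises `ValueError`
instead of silently truncating to the shorter sequence.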

Tests: same 2 pre-existing module-impl failures
(`test_module1_attaches_video_block_to_subtask_prompt`,
`test_module2_mid_episode_emits_paired_interjection_and_speech`) —
they were failing on `origin/feat/language-annotation-pipeline` before
this commit and continue to do so for the same reasons. 61/63 in the
language stack pass; pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-08 11:29:45 +02:00
parent 3a52a18b0e
commit 088c8371df
@@ -19,9 +19,8 @@ from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
from pathlib import Path from pathlib import Path
from typing import Any
from ..config import Module1Config from ..config import Module1Config
from ..frames import ( from ..frames import (
@@ -81,9 +80,7 @@ class PlanSubtasksMemoryModule:
# so the policy sees diverse phrasings during training. # so the policy sees diverse phrasings during training.
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0 t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
if self.config.n_task_rephrasings > 0 and effective_task: if self.config.n_task_rephrasings > 0 and effective_task:
rephrasings = self._generate_task_rephrasings( rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
effective_task, n=self.config.n_task_rephrasings
)
# Always include the effective task itself as the first variant # Always include the effective task itself as the first variant
# so the rotation is guaranteed to cover the source-of-truth # so the rotation is guaranteed to cover the source-of-truth
# phrasing, not just synthetic alternatives. # phrasing, not just synthetic alternatives.
@@ -133,9 +130,7 @@ class PlanSubtasksMemoryModule:
for i, span in enumerate(subtask_spans[1:], start=1): for i, span in enumerate(subtask_spans[1:], start=1):
completed = subtask_spans[i - 1]["text"] completed = subtask_spans[i - 1]["text"]
remaining = [s["text"] for s in subtask_spans[i:]] remaining = [s["text"] for s in subtask_spans[i:]]
mem_text = self._generate_memory( mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
record, prior_memory, completed, remaining, task=effective_task
)
if mem_text: if mem_text:
ts = _snap_to_frame(span["start"], record.frame_timestamps) ts = _snap_to_frame(span["start"], record.frame_timestamps)
rows.append( rows.append(
@@ -193,44 +188,50 @@ class PlanSubtasksMemoryModule:
return True return True
if len(task.split()) < int(self.config.derive_task_min_words): if len(task.split()) < int(self.config.derive_task_min_words):
return True return True
if task.lower() in self._PLACEHOLDER_TASKS: return task.lower() in self._PLACEHOLDER_TASKS
return True
return False # ------------------------------------------------------------------
# VLM call helpers (factored out: every Module-1 prompt below follows
# the same "build messages → single VLM call → pull a named field"
# shape, only differing in field name + post-processing).
# ------------------------------------------------------------------
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
    """Issue one VLM request and pull a single named field from its reply.

    ``messages`` is submitted to ``self.vlm.generate_json`` as a
    one-element batch, and only the first reply is inspected.

    Returns ``reply.get(field)`` when the reply is a ``dict`` (so a
    missing key yields ``None``); any non-dict reply also yields ``None``.
    """
    replies = self.vlm.generate_json([messages])
    reply = replies[0]
    # Anything other than a JSON object cannot carry a named field.
    return reply.get(field) if isinstance(reply, dict) else None
@staticmethod
def _text_message(text: str) -> list[dict[str, Any]]:
    """Wrap ``text`` as a single text-only user turn for ``generate_json``.

    The result is a one-element message list whose only entry has role
    ``"user"`` and a single ``"text"`` content part holding ``text``.
    """
    text_part = {"type": "text", "text": text}
    user_turn = {"role": "user", "content": [text_part]}
    return [user_turn]
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
    """Build one user turn: the episode's video block followed by ``prompt``.

    The video parts come from ``self._episode_video_block(record)``; the
    prompt is appended after them as a ``"text"`` content part.
    """
    parts = list(self._episode_video_block(record))
    parts.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": parts}]
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None: def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
"""Ask the VLM "what is this video about" with no task hint at all.""" """Ask the VLM "what is this video about" with no task hint at all."""
prompt = load_prompt("module_1_video_task") text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
video_block = self._episode_video_block(record) return text.strip() if isinstance(text, str) and text.strip() else None
content = [*video_block, {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("task"), str):
text = result["task"].strip()
if text:
return text
return None
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]: def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
"""Generate ``n`` text-only paraphrases of ``base_task``.""" """Generate ``n`` text-only paraphrases of ``base_task``."""
if n <= 0 or not base_task: if n <= 0 or not base_task:
return [] return []
prompt = load_prompt("module_1_task_rephrasings").format( prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
base_task=base_task, n=n raw = self._vlm_field(self._text_message(prompt), "rephrasings")
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
return []
raw = result.get("rephrasings")
if not isinstance(raw, list): if not isinstance(raw, list):
return [] return []
out: list[str] = [] out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
for item in raw: return [s for s in out if s][:n]
if isinstance(item, str):
cleaned = item.strip().strip('"').strip("'")
if cleaned:
out.append(cleaned)
return out[:n]
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]: def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
"""Same video block ``_generate_subtasks`` builds — extracted helper.""" """Same video block ``_generate_subtasks`` builds — extracted helper."""
@@ -245,9 +246,7 @@ class PlanSubtasksMemoryModule:
else [] else []
) )
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0] episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
target_count = max( target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
1, int(round(episode_duration * self.config.frames_per_second))
)
target_count = min(target_count, self.config.max_video_frames) target_count = min(target_count, self.config.max_video_frames)
video_frames = self.frame_provider.video_for_episode(record, target_count) video_frames = self.frame_provider.video_for_episode(record, target_count)
return to_video_block(video_frames) return to_video_block(video_frames)
@@ -270,9 +269,7 @@ class PlanSubtasksMemoryModule:
""" """
existing = staging.read("module_1") existing = staging.read("module_1")
spans = self._reconstruct_subtasks_from_rows(existing) spans = self._reconstruct_subtasks_from_rows(existing)
already_planned: set[float] = { already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
float(r["timestamp"]) for r in existing if r.get("style") == "plan"
}
new_rows = list(existing) new_rows = list(existing)
texts: list[str | None] = ( texts: list[str | None] = (
@@ -280,14 +277,12 @@ class PlanSubtasksMemoryModule:
if interjection_texts is None if interjection_texts is None
else [str(t) if t else None for t in interjection_texts] else [str(t) if t else None for t in interjection_texts]
) )
for raw_t, inter_text in zip(interjection_times, texts): for raw_t, inter_text in zip(interjection_times, texts, strict=True):
t = _snap_to_frame(raw_t, record.frame_timestamps) t = _snap_to_frame(raw_t, record.frame_timestamps)
if t in already_planned: if t in already_planned:
continue continue
already_planned.add(t) already_planned.add(t)
plan_text = self._generate_plan( plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
record, spans, refresh_t=t, interjection=inter_text
)
if plan_text is not None: if plan_text is not None:
new_rows.append( new_rows.append(
{ {
@@ -315,9 +310,7 @@ class PlanSubtasksMemoryModule:
last_t = t last_t = t
return out return out
def _generate_subtasks( def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
self, record: EpisodeRecord, *, task: str | None = None
) -> list[dict[str, Any]]:
if record.row_count == 0 or not record.frame_timestamps: if record.row_count == 0 or not record.frame_timestamps:
return [] return []
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0] episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
@@ -327,26 +320,7 @@ class PlanSubtasksMemoryModule:
max_steps=self.config.plan_max_steps, max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}", episode_duration=f"{episode_duration:.3f}",
) )
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider): spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
clip = episode_clip_path(record, self.frame_provider, cache_dir)
video_block = (
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
if clip is not None
else []
)
else:
target_count = max(
1,
int(round(episode_duration * self.config.frames_per_second)),
)
target_count = min(target_count, self.config.max_video_frames)
video_frames = self.frame_provider.video_for_episode(record, target_count)
video_block = to_video_block(video_frames)
content = [*video_block, {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
spans = result.get("subtasks") if isinstance(result, dict) else None
if not spans: if not spans:
return [] return []
# clamp to [t0, t_last] and sort # clamp to [t0, t_last] and sort
@@ -411,15 +385,9 @@ class PlanSubtasksMemoryModule:
# where in the episode the plan stands so the re-emission # where in the episode the plan stands so the re-emission
# is grounded. Should be rare — plan refreshes are # is grounded. Should be rare — plan refreshes are
# interjection-driven by design. # interjection-driven by design.
prompt += ( prompt += f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current subtask: {current_subtask!r}.)\n"
f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current " plan = self._vlm_field(self._text_message(prompt), "plan")
f"subtask: {current_subtask!r}.)\n" return plan.strip() if isinstance(plan, str) else None
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("plan"), str):
return result["plan"].strip()
return None
def _generate_memory( def _generate_memory(
self, self,
@@ -436,8 +404,5 @@ class PlanSubtasksMemoryModule:
completed_subtask=completed, completed_subtask=completed,
remaining_subtasks=", ".join(remaining) if remaining else "(none)", remaining_subtasks=", ".join(remaining) if remaining else "(none)",
) )
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] memory = self._vlm_field(self._text_message(prompt), "memory")
result = self.vlm.generate_json([messages])[0] return memory.strip() if isinstance(memory, str) else ""
if isinstance(result, dict) and isinstance(result.get("memory"), str):
return result["memory"].strip()
return ""