refactor(annotate): consolidate Module 1's prompt → VLM → JSON-extract pattern

Five Module 1 sub-prompts (`_derive_task_from_video`,
`_generate_task_rephrasings`, `_generate_subtasks`, `_generate_plan`,
`_generate_memory`) all repeated the same shape:

    result = self.vlm.generate_json([messages])[0]
    if isinstance(result, dict) and isinstance(result.get(<field>), <type>):
        ...

…each spelled out with slightly different field names + post-processing.

Three small helpers replace it (before/after example below):

* `_vlm_field(messages, field)` — single VLM call, returns
  ``result[field]`` or ``None``. Centralizes the
  ``generate_json([m])[0]`` + ``isinstance(dict)`` dance.
* `_text_message(text)` — wraps a string in the canonical user-message
  shape every text-only prompt builds inline.
* `_video_message(record, prompt)` — combines the episode video block
  with a prompt; replaces the duplicated video-block construction
  inside `_generate_subtasks` (which previously inlined the same
  ``use_video_url``/``frames_per_second``/``max_video_frames`` branches
  that `_episode_video_block` already implements).
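
Concretely, the `_generate_memory` tail (lifted from the diff below;
the other four call sites shrink the same way) goes from

    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    result = self.vlm.generate_json([messages])[0]
    if isinstance(result, dict) and isinstance(result.get("memory"), str):
        return result["memory"].strip()
    return ""

to

    memory = self._vlm_field(self._text_message(prompt), "memory")
    return memory.strip() if isinstance(memory, str) else ""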

Net -35 LOC. Each call site is now 3-5 lines instead of 10-20. The
public method signatures are unchanged, so tests don't move.

Drive-by: the trailing if/return in `_task_seems_bad` collapsed per
ruff SIM103; the `zip` in `run_plan_updates` now passes `strict=True`
per ruff B905.
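
Both are mechanical. From the diff below, the SIM103 collapse is

    # before
    if task.lower() in self._PLACEHOLDER_TASKS:
        return True
    return False

    # after
    return task.lower() in self._PLACEHOLDER_TASKS

and the B905 fix makes the zip raise instead of silently truncating if
the interjection times and texts ever disagree in length:

    for raw_t, inter_text in zip(interjection_times, texts, strict=True):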

Tests: same 2 pre-existing module-impl failures
(`test_module1_attaches_video_block_to_subtask_prompt`,
`test_module2_mid_episode_emits_paired_interjection_and_speech`) —
they were failing on `origin/feat/language-annotation-pipeline` before
this commit and continue to do so for the same reasons. 61/63 in the
language stack pass; pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -19,9 +19,8 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from dataclasses import dataclass, field
-from typing import Any
 from pathlib import Path
 from typing import Any
 
 from ..config import Module1Config
 from ..frames import (
@@ -81,9 +80,7 @@ class PlanSubtasksMemoryModule:
         # so the policy sees diverse phrasings during training.
         t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
         if self.config.n_task_rephrasings > 0 and effective_task:
-            rephrasings = self._generate_task_rephrasings(
-                effective_task, n=self.config.n_task_rephrasings
-            )
+            rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
             # Always include the effective task itself as the first variant
             # so the rotation is guaranteed to cover the source-of-truth
             # phrasing, not just synthetic alternatives.
@@ -133,9 +130,7 @@ class PlanSubtasksMemoryModule:
         for i, span in enumerate(subtask_spans[1:], start=1):
             completed = subtask_spans[i - 1]["text"]
             remaining = [s["text"] for s in subtask_spans[i:]]
-            mem_text = self._generate_memory(
-                record, prior_memory, completed, remaining, task=effective_task
-            )
+            mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
             if mem_text:
                 ts = _snap_to_frame(span["start"], record.frame_timestamps)
                 rows.append(
@@ -193,44 +188,50 @@ class PlanSubtasksMemoryModule:
             return True
         if len(task.split()) < int(self.config.derive_task_min_words):
             return True
-        if task.lower() in self._PLACEHOLDER_TASKS:
-            return True
-        return False
+        return task.lower() in self._PLACEHOLDER_TASKS
+
+    # ------------------------------------------------------------------
+    # VLM call helpers (factored out: every Module-1 prompt below follows
+    # the same "build messages → single VLM call → pull a named field"
+    # shape, only differing in field name + post-processing).
+    # ------------------------------------------------------------------
+
+    def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
+        """Run a single VLM call and return ``result[field]`` or ``None``.
+
+        Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
+        dance every prompt-call site needs.
+        """
+        result = self.vlm.generate_json([messages])[0]
+        if isinstance(result, dict):
+            return result.get(field)
+        return None
+
+    @staticmethod
+    def _text_message(text: str) -> list[dict[str, Any]]:
+        """One-shot text-only user message wrapped for ``generate_json``."""
+        return [{"role": "user", "content": [{"type": "text", "text": text}]}]
+
+    def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
+        """User message combining the episode video block with ``prompt``."""
+        content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
+        return [{"role": "user", "content": content}]
 
     def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
         """Ask the VLM "what is this video about" with no task hint at all."""
-        prompt = load_prompt("module_1_video_task")
-        video_block = self._episode_video_block(record)
-        content = [*video_block, {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("task"), str):
-            text = result["task"].strip()
-            if text:
-                return text
-        return None
+        text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
+        return text.strip() if isinstance(text, str) and text.strip() else None
 
     def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
         """Generate ``n`` text-only paraphrases of ``base_task``."""
         if n <= 0 or not base_task:
             return []
-        prompt = load_prompt("module_1_task_rephrasings").format(
-            base_task=base_task, n=n
-        )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if not isinstance(result, dict):
-            return []
-        raw = result.get("rephrasings")
+        prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
+        raw = self._vlm_field(self._text_message(prompt), "rephrasings")
         if not isinstance(raw, list):
             return []
-        out: list[str] = []
-        for item in raw:
-            if isinstance(item, str):
-                cleaned = item.strip().strip('"').strip("'")
-                if cleaned:
-                    out.append(cleaned)
-        return out[:n]
+        out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
+        return [s for s in out if s][:n]
 
     def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
         """Same video block ``_generate_subtasks`` builds — extracted helper."""
@@ -245,9 +246,7 @@ class PlanSubtasksMemoryModule:
             else []
         )
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
-        target_count = max(
-            1, int(round(episode_duration * self.config.frames_per_second))
-        )
+        target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
         target_count = min(target_count, self.config.max_video_frames)
         video_frames = self.frame_provider.video_for_episode(record, target_count)
         return to_video_block(video_frames)
@@ -270,9 +269,7 @@ class PlanSubtasksMemoryModule:
         """
         existing = staging.read("module_1")
         spans = self._reconstruct_subtasks_from_rows(existing)
-        already_planned: set[float] = {
-            float(r["timestamp"]) for r in existing if r.get("style") == "plan"
-        }
+        already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
         new_rows = list(existing)
 
         texts: list[str | None] = (
@@ -280,14 +277,12 @@ class PlanSubtasksMemoryModule:
             if interjection_texts is None
             else [str(t) if t else None for t in interjection_texts]
         )
-        for raw_t, inter_text in zip(interjection_times, texts):
+        for raw_t, inter_text in zip(interjection_times, texts, strict=True):
             t = _snap_to_frame(raw_t, record.frame_timestamps)
             if t in already_planned:
                 continue
             already_planned.add(t)
-            plan_text = self._generate_plan(
-                record, spans, refresh_t=t, interjection=inter_text
-            )
+            plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
             if plan_text is not None:
                 new_rows.append(
                     {
@@ -315,9 +310,7 @@ class PlanSubtasksMemoryModule:
             last_t = t
         return out
 
-    def _generate_subtasks(
-        self, record: EpisodeRecord, *, task: str | None = None
-    ) -> list[dict[str, Any]]:
+    def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
         if record.row_count == 0 or not record.frame_timestamps:
             return []
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
@@ -327,26 +320,7 @@ class PlanSubtasksMemoryModule:
             max_steps=self.config.plan_max_steps,
             episode_duration=f"{episode_duration:.3f}",
         )
-        if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
-            cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
-            clip = episode_clip_path(record, self.frame_provider, cache_dir)
-            video_block = (
-                to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
-                if clip is not None
-                else []
-            )
-        else:
-            target_count = max(
-                1,
-                int(round(episode_duration * self.config.frames_per_second)),
-            )
-            target_count = min(target_count, self.config.max_video_frames)
-            video_frames = self.frame_provider.video_for_episode(record, target_count)
-            video_block = to_video_block(video_frames)
-        content = [*video_block, {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
-        spans = result.get("subtasks") if isinstance(result, dict) else None
+        spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
         if not spans:
             return []
         # clamp to [t0, t_last] and sort
@@ -411,15 +385,9 @@ class PlanSubtasksMemoryModule:
             # where in the episode the plan stands so the re-emission
             # is grounded. Should be rare — plan refreshes are
             # interjection-driven by design.
-            prompt += (
-                f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current "
-                f"subtask: {current_subtask!r}.)\n"
-            )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("plan"), str):
-            return result["plan"].strip()
-        return None
+            prompt += f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current subtask: {current_subtask!r}.)\n"
+        plan = self._vlm_field(self._text_message(prompt), "plan")
+        return plan.strip() if isinstance(plan, str) else None
 
     def _generate_memory(
         self,
@@ -436,8 +404,5 @@ class PlanSubtasksMemoryModule:
             completed_subtask=completed,
             remaining_subtasks=", ".join(remaining) if remaining else "(none)",
         )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("memory"), str):
-            return result["memory"].strip()
-        return ""
+        memory = self._vlm_field(self._text_message(prompt), "memory")
+        return memory.strip() if isinstance(memory, str) else ""