annotate: windowed subtask generation for constant temporal density

Long episodes no longer get sparse subtasks. Previously a long episode
was subsampled to max_video_frames=32 across its whole duration (~1
frame/4s for a 2-min clip). New opt-in windowing keeps a CONSTANT
frames_per_second density by splitting the episode into fixed-length
windows and running the subtask chain per window.

New PlanConfig.subtask_window_seconds (default 0.0 = off). When > 0 and
the episode is longer than one window:
  * episode is split into consecutive [w0, w1] windows of this length
  * each window's frames are sampled at frames_per_second, endpoints
    included and capped at max_video_frames (so a 32s window at 1 fps
    = 32 frames, filling but not exceeding the per-call context budget)
  * the full describe -> segment -> verify chain runs PER window, in
    window-relative time [0, L]; spans are offset back to absolute
  * all windows' spans are merged, frame-snap-deduped, and stitched into
    one contiguous whole-episode cover

Implementation:
  * _episode_video_block / _video_message / _describe_episode /
    _verify_subtasks gain an optional window=(w0,w1); when set they
    embed frames sampled in that absolute range at frames_per_second
    (video_url path skipped — it's whole-episode).
  * _clean_spans gains bounds= (override clamp range, for window-relative
    spans) and dedupe= (skip frame-snap until the merged absolute set).
  * new _generate_subtasks_windowed + _subtasks_for_window orchestrate
    the loop; _generate_subtasks branches to them when window_s > 0.

run_hf_job.py: --plan.subtask_window_seconds=32 (32s windows at 1 fps).
Cost scales with episode length (chain calls × ceil(duration/window)).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-02 16:26:14 +02:00
parent 3236c6ee4a
commit 518e191337
3 changed files with 187 additions and 17 deletions
+6
View File
@@ -70,6 +70,12 @@ CMD = (
"--plan.use_video_url=false "
"--plan.frames_per_second=1.0 "
"--plan.max_video_frames=32 "
# Constant 1 fps density via windowing: episodes longer than 32s are
# split into 32-second windows (each 32 frames @ 1 fps, fits context),
# so long episodes get MORE subtasks instead of a sparser whole-episode
# view. describe->segment->verify runs per window; spans are merged +
# stitched to a contiguous whole-episode cover. 0 disables.
"--plan.subtask_window_seconds=32 "
# IMPORTANT for RoboCasa: the dataset's task string ("Navigate to the
# stove", "Pick the mug...") is authoritative and is what eval uses.
# ``derive_task_from_video=off`` keeps that canonical task driving
@@ -58,6 +58,19 @@ class PlanConfig:
frames_per_second: float = 1.0
max_video_frames: int = 32
# Windowed subtask generation for CONSTANT temporal density. When > 0
# and an episode is longer than this many seconds, the plan module
# processes the episode in consecutive windows of this length, each
# sampled at ``frames_per_second``, instead of subsampling the whole
# episode to ``max_video_frames`` (which makes long episodes sparse).
# The describe -> segment -> verify chain runs per window; results are
# offset to absolute time, merged, and stitched into a contiguous
# whole-episode cover. Cost scales with episode length (≈ chain calls
# × ceil(duration / window)). Set to ~``max_video_frames /
# frames_per_second`` (e.g. 32s at 1 fps) so each window fills — but
# never exceeds — the per-call frame budget. 0 disables (single
# whole-episode call).
subtask_window_seconds: float = 0.0
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
@@ -272,9 +272,14 @@ class PlanSubtasksMemoryModule:
"""One-shot text-only user message wrapped for ``generate_json``."""
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
"""User message combining the episode video block with ``prompt``."""
content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
def _video_message(
self,
record: EpisodeRecord,
prompt: str,
window: tuple[float, float] | None = None,
) -> list[dict[str, Any]]:
"""User message combining the (optionally windowed) video block with ``prompt``."""
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
@@ -442,8 +447,10 @@ class PlanSubtasksMemoryModule:
flat.append(key)
return flat
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
"""Same video block ``_generate_subtasks`` builds — extracted helper.
def _episode_video_block(
self, record: EpisodeRecord, window: tuple[float, float] | None = None
) -> list[dict[str, Any]]:
"""Video block for the segmentation / describe / verify prompts.
Always returns a block that actually carries the video. When
``use_video_url`` is set we try the server-side ``video_url``
@@ -452,9 +459,29 @@ class PlanSubtasksMemoryModule:
block — an empty block would leave the VLM with no visual
grounding at all and it would hallucinate subtasks purely from
the task text.
When ``window=(w0, w1)`` is given (windowed subtask generation,
``subtask_window_seconds > 0``), embed frames sampled at the FIXED
``frames_per_second`` rate within ``[w0, w1]`` — constant temporal
density regardless of episode length, so long episodes are split
into windows rather than subsampled to a sparse 32-frame whole-
episode view. The ``video_url`` path is skipped for windows (it is
a whole-episode clip). ``max_video_frames`` still caps each window
as a context-budget safety net.
"""
if not record.frame_timestamps:
return []
if window is not None:
w0, w1 = float(window[0]), float(window[1])
dur = max(0.0, w1 - w0)
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_video_frames)
if n <= 1 or dur <= 0.0:
timestamps = [0.5 * (w0 + w1)]
else:
step = dur / (n - 1)
timestamps = [w0 + i * step for i in range(n)]
return to_video_block(self.frame_provider.frames_at(record, timestamps))
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
clip = self.frame_provider.episode_clip_path(record, cache_dir)
@@ -541,6 +568,17 @@ class PlanSubtasksMemoryModule:
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
effective_task = task if task is not None else record.episode_task
# ---- Windowed path (constant temporal density) ---------------
# When ``subtask_window_seconds > 0`` and the episode is longer
# than one window, process the episode in fixed-length windows so
# the VLM always sees ``frames_per_second`` density (instead of a
# sparse 32-frame whole-episode view). Each window runs the full
# describe -> segment -> verify chain on its own frames; results
# are merged + stitched into a contiguous whole-episode cover.
window_s = float(getattr(self.config, "subtask_window_seconds", 0.0) or 0.0)
if window_s > 0.0 and episode_duration > window_s:
return self._generate_subtasks_windowed(record, effective_task, window_s)
# ---- Pass 1 (optional): grounding description ----------------
observation_block = ""
if getattr(self.config, "subtask_describe_first", False):
@@ -586,6 +624,91 @@ class PlanSubtasksMemoryModule:
return cleaned
def _generate_subtasks_windowed(
self, record: EpisodeRecord, task: str, window_s: float
) -> list[dict[str, Any]]:
"""Subtask generation in fixed-length windows at constant fps.
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
seconds, runs the describe -> segment -> verify chain on each
window's own frames (sampled at ``frames_per_second``), offsets
each window's spans back to absolute episode time, then merges +
stitches into a contiguous whole-episode cover.
"""
t0 = float(record.frame_timestamps[0])
t_last = float(record.frame_timestamps[-1])
all_spans: list[dict[str, Any]] = []
w0 = t0
n_windows = 0
while w0 < t_last - 1e-6:
w1 = min(w0 + window_s, t_last)
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
n_windows += 1
w0 = w1
logger.info(
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
record.episode_index,
n_windows,
window_s,
len(all_spans),
)
# Merge across windows: clamp to the absolute episode, sort, and
# frame-snap to distinct starts (handles any boundary collisions).
cleaned = self._clean_spans(all_spans, record)
if not cleaned:
return []
return self._stitch_full_coverage(cleaned, record)
def _subtasks_for_window(
    self, record: EpisodeRecord, task: str, w0: float, w1: float
) -> list[dict[str, Any]]:
    """Produce absolute-time subtask spans for the single window ``[w0, w1]``.

    Runs the optional describe pass, the segmentation prompt and the
    optional verify pass on frames taken from this window only. The
    prompts are given the window length as the clip duration, so span
    times come back window-relative in ``[0, L]``; they are shifted by
    ``w0`` before being returned. Returns ``[]`` when cleaning or
    verification leaves no span.
    """
    cfg = self.config
    clip = (w0, w1)
    clip_len = max(0.0, w1 - w0)

    # Optional grounding pass: a description of this window's frames that
    # the segmentation prompt is told to stick to.
    grounding = ""
    if getattr(cfg, "subtask_describe_first", False):
        description = self._describe_episode(record, task, window=clip)
        if description:
            grounding = (
                "You watched this video clip and described, chronologically, "
                "ONLY what the robot actually does:\n"
                f'"""{description}"""\n\n'
                "Segment THAT grounded description (cross-checked against "
                "the clip) into atomic subtasks. Do not introduce any "
                "action that is not in your description above.\n\n"
            )

    prompt = load_prompt("module_1_subtasks").format(
        episode_task=task,
        min_subtask_seconds=cfg.min_subtask_seconds,
        max_steps=cfg.plan_max_steps,
        episode_duration=f"{clip_len:.3f}",
        observation_block=grounding,
    )
    raw = self._vlm_field(self._video_message(record, prompt, window=clip), "subtasks")

    # Clamp in window-relative time; the frame-snap dedupe is deferred to
    # the merged, absolute-time set built by the caller.
    spans = self._clean_spans(raw, record, bounds=(0.0, clip_len), dedupe=False)
    if spans and getattr(cfg, "subtask_verify", False):
        spans = self._verify_subtasks(record, task, spans, window=clip)
    if not spans:
        return []

    # Shift window-relative times back onto the absolute episode timeline.
    for span in spans:
        span["start"] = w0 + float(span["start"])
        span["end"] = w0 + float(span["end"])
    return spans
def _stitch_full_coverage(
self, spans: list[dict[str, Any]], record: EpisodeRecord
) -> list[dict[str, Any]]:
@@ -619,13 +742,28 @@ class PlanSubtasksMemoryModule:
return spans
def _clean_spans(
self, spans: Any, record: EpisodeRecord
self,
spans: Any,
record: EpisodeRecord,
bounds: tuple[float, float] | None = None,
dedupe: bool = True,
) -> list[dict[str, Any]]:
"""Clamp / sort / dedupe raw VLM subtask spans into valid rows."""
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
``bounds`` overrides the clamp range — pass the window's
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
``None`` to clamp to the whole episode ``[t0, t_last]``.
``dedupe`` runs the frame-snap distinct-start step; skip it for
window-relative spans (frame snapping is done once on the merged,
absolute-time set).
"""
if not spans:
return []
t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1]
if bounds is not None:
lo, hi = float(bounds[0]), float(bounds[1])
else:
lo = record.frame_timestamps[0]
hi = record.frame_timestamps[-1]
cleaned: list[dict[str, Any]] = []
for span in spans:
try:
@@ -634,20 +772,24 @@ class PlanSubtasksMemoryModule:
text = str(span["text"]).strip()
except (KeyError, ValueError, TypeError):
continue
start = max(t0, min(start, t_last))
end = max(t0, min(end, t_last))
start = max(lo, min(start, hi))
end = max(lo, min(end, hi))
if end < start:
start, end = end, start
if not text:
continue
cleaned.append({"text": text, "start": start, "end": end})
cleaned.sort(key=lambda s: s["start"])
return self._dedupe_starts_to_distinct_frames(cleaned, record)
if dedupe:
return self._dedupe_starts_to_distinct_frames(cleaned, record)
return cleaned
def _describe_episode(self, record: EpisodeRecord, task: str) -> str:
"""Grounding pass: free-form chronological description of the video."""
def _describe_episode(
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
) -> str:
"""Grounding pass: free-form chronological description of the (windowed) video."""
prompt = load_prompt("module_1_subtask_describe").format(episode_task=task)
text = self._vlm_field(self._video_message(record, prompt), "description")
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
return text.strip() if isinstance(text, str) and text.strip() else ""
def _verify_subtasks(
@@ -655,6 +797,7 @@ class PlanSubtasksMemoryModule:
record: EpisodeRecord,
task: str,
spans: list[dict[str, Any]],
window: tuple[float, float] | None = None,
) -> list[dict[str, Any]]:
"""Adversarial pass: drop proposed subtasks not visible in the video.
@@ -674,8 +817,16 @@ class PlanSubtasksMemoryModule:
prompt = load_prompt("module_1_subtask_verify").format(
episode_task=task, subtasks_json=subtasks_json
)
kept_raw = self._vlm_field(self._video_message(record, prompt), "subtasks")
kept = self._clean_spans(kept_raw, record)
kept_raw = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
# Windowed verify: the video is sampled from the absolute window
# ``[w0, w1]`` but the model perceives it as a clip starting at 0,
# so proposed + returned times are window-RELATIVE in ``[0, L]``.
# Clamp to that relative range and skip the absolute frame-snap
# dedupe (done once later on the merged absolute-time set).
clamp = (0.0, float(window[1] - window[0])) if window is not None else None
kept = self._clean_spans(
kept_raw, record, bounds=clamp, dedupe=window is None
)
if not kept:
logger.info(
"episode %d: verify pass returned nothing — keeping the %d "