mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
annotate: windowed subtask generation for constant temporal density
Long episodes no longer get sparse subtasks. Previously a long episode
was subsampled to max_video_frames=32 across its whole duration (~1
frame/4s for a 2-min clip). New opt-in windowing keeps a CONSTANT
frames_per_second density by splitting the episode into fixed-length
windows and running the subtask chain per window.
New PlanConfig.subtask_window_seconds (default 0.0 = off). When > 0 and
the episode is longer than one window:
* episode is split into consecutive [w0, w1] windows of this length
* each window's frames are sampled at frames_per_second (so a 32s
window at 1 fps = 32 frames, filling but not exceeding the per-call
context budget)
* the full describe -> segment -> verify chain runs PER window, in
window-relative time [0, L]; spans are offset back to absolute
* all windows' spans are merged, frame-snap-deduped, and stitched into
one contiguous whole-episode cover
Implementation:
* _episode_video_block / _video_message / _describe_episode /
_verify_subtasks gain an optional window=(w0,w1); when set they
embed frames sampled in that absolute range at frames_per_second
(video_url path skipped — it's whole-episode).
* _clean_spans gains bounds= (override clamp range, for window-relative
spans) and dedupe= (skip frame-snap until the merged absolute set).
* new _generate_subtasks_windowed + _subtasks_for_window orchestrate
the loop; _generate_subtasks branches to them when window_s > 0.
run_hf_job.py: --plan.subtask_window_seconds=32 (32s windows at 1 fps).
Cost scales with episode length (chain calls × ceil(duration/window)).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -70,6 +70,12 @@ CMD = (
|
||||
"--plan.use_video_url=false "
|
||||
"--plan.frames_per_second=1.0 "
|
||||
"--plan.max_video_frames=32 "
|
||||
# Constant 1 fps density via windowing: episodes longer than 32s are
|
||||
# split into 32-second windows (each 32 frames @ 1 fps, fits context),
|
||||
# so long episodes get MORE subtasks instead of a sparser whole-episode
|
||||
# view. describe->segment->verify runs per window; spans are merged +
|
||||
# stitched to a contiguous whole-episode cover. 0 disables.
|
||||
"--plan.subtask_window_seconds=32 "
|
||||
# IMPORTANT for RoboCasa: the dataset's task string ("Navigate to the
|
||||
# stove", "Pick the mug...") is authoritative and is what eval uses.
|
||||
# ``derive_task_from_video=off`` keeps that canonical task driving
|
||||
|
||||
@@ -58,6 +58,19 @@ class PlanConfig:
|
||||
frames_per_second: float = 1.0
|
||||
max_video_frames: int = 32
|
||||
|
||||
# Windowed subtask generation for CONSTANT temporal density. When > 0
|
||||
# and an episode is longer than this many seconds, the plan module
|
||||
# processes the episode in consecutive windows of this length, each
|
||||
# sampled at ``frames_per_second``, instead of subsampling the whole
|
||||
# episode to ``max_video_frames`` (which makes long episodes sparse).
|
||||
# The describe -> segment -> verify chain runs per window; results are
|
||||
# offset to absolute time, merged, and stitched into a contiguous
|
||||
# whole-episode cover. Cost scales with episode length (≈ chain calls
|
||||
# × ceil(duration / window)). Set to ~max_video_frames / frames_per_
|
||||
# second (e.g. 32s at 1 fps) so each window fills — but never exceeds —
|
||||
# the per-call frame budget. 0 disables (single whole-episode call).
|
||||
subtask_window_seconds: float = 0.0
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
|
||||
@@ -272,9 +272,14 @@ class PlanSubtasksMemoryModule:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
|
||||
"""User message combining the episode video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
|
||||
def _video_message(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prompt: str,
|
||||
window: tuple[float, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""User message combining the (optionally windowed) video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
@@ -442,8 +447,10 @@ class PlanSubtasksMemoryModule:
|
||||
flat.append(key)
|
||||
return flat
|
||||
|
||||
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
|
||||
"""Same video block ``_generate_subtasks`` builds — extracted helper.
|
||||
def _episode_video_block(
|
||||
self, record: EpisodeRecord, window: tuple[float, float] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Video block for the segmentation / describe / verify prompts.
|
||||
|
||||
Always returns a block that actually carries the video. When
|
||||
``use_video_url`` is set we try the server-side ``video_url``
|
||||
@@ -452,9 +459,29 @@ class PlanSubtasksMemoryModule:
|
||||
block — an empty block would leave the VLM with no visual
|
||||
grounding at all and it would hallucinate subtasks purely from
|
||||
the task text.
|
||||
|
||||
When ``window=(w0, w1)`` is given (windowed subtask generation,
|
||||
``subtask_window_seconds > 0``), embed frames sampled at the FIXED
|
||||
``frames_per_second`` rate within ``[w0, w1]`` — constant temporal
|
||||
density regardless of episode length, so long episodes are split
|
||||
into windows rather than subsampled to a sparse 32-frame whole-
|
||||
episode view. The ``video_url`` path is skipped for windows (it is
|
||||
a whole-episode clip). ``max_video_frames`` still caps each window
|
||||
as a context-budget safety net.
|
||||
"""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if window is not None:
|
||||
w0, w1 = float(window[0]), float(window[1])
|
||||
dur = max(0.0, w1 - w0)
|
||||
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_video_frames)
|
||||
if n <= 1 or dur <= 0.0:
|
||||
timestamps = [0.5 * (w0 + w1)]
|
||||
else:
|
||||
step = dur / (n - 1)
|
||||
timestamps = [w0 + i * step for i in range(n)]
|
||||
return to_video_block(self.frame_provider.frames_at(record, timestamps))
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||
@@ -541,6 +568,17 @@ class PlanSubtasksMemoryModule:
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
effective_task = task if task is not None else record.episode_task
|
||||
|
||||
# ---- Windowed path (constant temporal density) ---------------
|
||||
# When ``subtask_window_seconds > 0`` and the episode is longer
|
||||
# than one window, process the episode in fixed-length windows so
|
||||
# the VLM always sees ``frames_per_second`` density (instead of a
|
||||
# sparse 32-frame whole-episode view). Each window runs the full
|
||||
# describe -> segment -> verify chain on its own frames; results
|
||||
# are merged + stitched into a contiguous whole-episode cover.
|
||||
window_s = float(getattr(self.config, "subtask_window_seconds", 0.0) or 0.0)
|
||||
if window_s > 0.0 and episode_duration > window_s:
|
||||
return self._generate_subtasks_windowed(record, effective_task, window_s)
|
||||
|
||||
# ---- Pass 1 (optional): grounding description ----------------
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
@@ -586,6 +624,91 @@ class PlanSubtasksMemoryModule:
|
||||
|
||||
return cleaned
|
||||
|
||||
def _generate_subtasks_windowed(
|
||||
self, record: EpisodeRecord, task: str, window_s: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Subtask generation in fixed-length windows at constant fps.
|
||||
|
||||
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
|
||||
seconds, runs the describe -> segment -> verify chain on each
|
||||
window's own frames (sampled at ``frames_per_second``), offsets
|
||||
each window's spans back to absolute episode time, then merges +
|
||||
stitches into a contiguous whole-episode cover.
|
||||
"""
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
all_spans: list[dict[str, Any]] = []
|
||||
w0 = t0
|
||||
n_windows = 0
|
||||
while w0 < t_last - 1e-6:
|
||||
w1 = min(w0 + window_s, t_last)
|
||||
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
|
||||
n_windows += 1
|
||||
w0 = w1
|
||||
logger.info(
|
||||
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
|
||||
record.episode_index,
|
||||
n_windows,
|
||||
window_s,
|
||||
len(all_spans),
|
||||
)
|
||||
# Merge across windows: clamp to the absolute episode, sort, and
|
||||
# frame-snap to distinct starts (handles any boundary collisions).
|
||||
cleaned = self._clean_spans(all_spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
return self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
def _subtasks_for_window(
|
||||
self, record: EpisodeRecord, task: str, w0: float, w1: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run describe -> segment -> verify on one ``[w0, w1]`` window.
|
||||
|
||||
The model works in window-RELATIVE time ``[0, L]`` (it perceives
|
||||
the window as a clip starting at 0); spans are offset back to
|
||||
absolute ``[w0, w1]`` before returning.
|
||||
"""
|
||||
window = (w0, w1)
|
||||
win_len = max(0.0, w1 - w0)
|
||||
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, task, window=window)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video clip and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the clip) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
prompt = load_prompt("module_1_subtasks").format(
|
||||
episode_task=task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{win_len:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||
# Window-relative clamp; no frame-snap dedupe yet (done on the
|
||||
# merged absolute set).
|
||||
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
if getattr(self.config, "subtask_verify", False):
|
||||
cleaned = self._verify_subtasks(record, task, cleaned, window=window)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# Offset window-relative spans back to absolute episode time.
|
||||
for s in cleaned:
|
||||
s["start"] = w0 + float(s["start"])
|
||||
s["end"] = w0 + float(s["end"])
|
||||
return cleaned
|
||||
|
||||
def _stitch_full_coverage(
|
||||
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
@@ -619,13 +742,28 @@ class PlanSubtasksMemoryModule:
|
||||
return spans
|
||||
|
||||
def _clean_spans(
|
||||
self, spans: Any, record: EpisodeRecord
|
||||
self,
|
||||
spans: Any,
|
||||
record: EpisodeRecord,
|
||||
bounds: tuple[float, float] | None = None,
|
||||
dedupe: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Clamp / sort / dedupe raw VLM subtask spans into valid rows."""
|
||||
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
|
||||
|
||||
``bounds`` overrides the clamp range — pass the window's
|
||||
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
|
||||
``None`` to clamp to the whole episode ``[t0, t_last]``.
|
||||
``dedupe`` runs the frame-snap distinct-start step; skip it for
|
||||
window-relative spans (frame snapping is done once on the merged,
|
||||
absolute-time set).
|
||||
"""
|
||||
if not spans:
|
||||
return []
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if bounds is not None:
|
||||
lo, hi = float(bounds[0]), float(bounds[1])
|
||||
else:
|
||||
lo = record.frame_timestamps[0]
|
||||
hi = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
@@ -634,20 +772,24 @@ class PlanSubtasksMemoryModule:
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(t0, min(start, t_last))
|
||||
end = max(t0, min(end, t_last))
|
||||
start = max(lo, min(start, hi))
|
||||
end = max(lo, min(end, hi))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
if dedupe:
|
||||
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
return cleaned
|
||||
|
||||
def _describe_episode(self, record: EpisodeRecord, task: str) -> str:
|
||||
"""Grounding pass: free-form chronological description of the video."""
|
||||
def _describe_episode(
|
||||
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
|
||||
) -> str:
|
||||
"""Grounding pass: free-form chronological description of the (windowed) video."""
|
||||
prompt = load_prompt("module_1_subtask_describe").format(episode_task=task)
|
||||
text = self._vlm_field(self._video_message(record, prompt), "description")
|
||||
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else ""
|
||||
|
||||
def _verify_subtasks(
|
||||
@@ -655,6 +797,7 @@ class PlanSubtasksMemoryModule:
|
||||
record: EpisodeRecord,
|
||||
task: str,
|
||||
spans: list[dict[str, Any]],
|
||||
window: tuple[float, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Adversarial pass: drop proposed subtasks not visible in the video.
|
||||
|
||||
@@ -674,8 +817,16 @@ class PlanSubtasksMemoryModule:
|
||||
prompt = load_prompt("module_1_subtask_verify").format(
|
||||
episode_task=task, subtasks_json=subtasks_json
|
||||
)
|
||||
kept_raw = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||
kept = self._clean_spans(kept_raw, record)
|
||||
kept_raw = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||
# Windowed verify: the video is sampled from the absolute window
|
||||
# ``[w0, w1]`` but the model perceives it as a clip starting at 0,
|
||||
# so proposed + returned times are window-RELATIVE in ``[0, L]``.
|
||||
# Clamp to that relative range and skip the absolute frame-snap
|
||||
# dedupe (done once later on the merged absolute-time set).
|
||||
clamp = (0.0, float(window[1] - window[0])) if window is not None else None
|
||||
kept = self._clean_spans(
|
||||
kept_raw, record, bounds=clamp, dedupe=window is None
|
||||
)
|
||||
if not kept:
|
||||
logger.info(
|
||||
"episode %d: verify pass returned nothing — keeping the %d "
|
||||
|
||||
Reference in New Issue
Block a user