diff --git a/examples/annotations/run_hf_job.py b/examples/annotations/run_hf_job.py index cbae22796..86575f72f 100644 --- a/examples/annotations/run_hf_job.py +++ b/examples/annotations/run_hf_job.py @@ -70,6 +70,12 @@ CMD = ( "--plan.use_video_url=false " "--plan.frames_per_second=1.0 " "--plan.max_video_frames=32 " + # Constant 1 fps density via windowing: episodes longer than 32s are + # split into 32-second windows (each 32 frames @ 1 fps, fits context), + # so long episodes get MORE subtasks instead of a sparser whole-episode + # view. describe->segment->verify runs per window; spans are merged + + # stitched to a contiguous whole-episode cover. 0 disables. + "--plan.subtask_window_seconds=32 " # IMPORTANT for RoboCasa: the dataset's task string ("Navigate to the # stove", "Pick the mug...") is authoritative and is what eval uses. # ``derive_task_from_video=off`` keeps that canonical task driving diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py index 37371a7fb..414824cfb 100644 --- a/src/lerobot/annotations/steerable_pipeline/config.py +++ b/src/lerobot/annotations/steerable_pipeline/config.py @@ -58,6 +58,19 @@ class PlanConfig: frames_per_second: float = 1.0 max_video_frames: int = 32 + # Windowed subtask generation for CONSTANT temporal density. When > 0 + # and an episode is longer than this many seconds, the plan module + # processes the episode in consecutive windows of this length, each + # sampled at ``frames_per_second``, instead of subsampling the whole + # episode to ``max_video_frames`` (which makes long episodes sparse). + # The describe -> segment -> verify chain runs per window; results are + # offset to absolute time, merged, and stitched into a contiguous + # whole-episode cover. Cost scales with episode length (≈ chain calls + # × ceil(duration / window)). Set to ~max_video_frames / frames_per_ + # second (e.g. 32s at 1 fps) so each window fills — but never exceeds — + # the per-call frame budget. 0 disables (single whole-episode call). + subtask_window_seconds: float = 0.0 + min_subtask_seconds: float = 1.5 plan_max_steps: int = 8 diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py index 4ffef49c1..991ee3a3b 100644 --- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -272,9 +272,14 @@ class PlanSubtasksMemoryModule: """One-shot text-only user message wrapped for ``generate_json``.""" return [{"role": "user", "content": [{"type": "text", "text": text}]}] - def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]: - """User message combining the episode video block with ``prompt``.""" - content = [*self._episode_video_block(record), {"type": "text", "text": prompt}] + def _video_message( + self, + record: EpisodeRecord, + prompt: str, + window: tuple[float, float] | None = None, + ) -> list[dict[str, Any]]: + """User message combining the (optionally windowed) video block with ``prompt``.""" + content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}] return [{"role": "user", "content": content}] def _derive_task_from_video(self, record: EpisodeRecord) -> str | None: @@ -442,8 +447,10 @@ class PlanSubtasksMemoryModule: flat.append(key) return flat - def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]: - """Same video block ``_generate_subtasks`` builds — extracted helper. + def _episode_video_block( + self, record: EpisodeRecord, window: tuple[float, float] | None = None + ) -> list[dict[str, Any]]: + """Video block for the segmentation / describe / verify prompts. Always returns a block that actually carries the video. When ``use_video_url`` is set we try the server-side ``video_url`` @@ -452,9 +459,29 @@ class PlanSubtasksMemoryModule: block — an empty block would leave the VLM with no visual grounding at all and it would hallucinate subtasks purely from the task text. + + When ``window=(w0, w1)`` is given (windowed subtask generation, + ``subtask_window_seconds > 0``), embed frames sampled at the FIXED + ``frames_per_second`` rate within ``[w0, w1]`` — constant temporal + density regardless of episode length, so long episodes are split + into windows rather than subsampled to a sparse 32-frame whole- + episode view. The ``video_url`` path is skipped for windows (it is + a whole-episode clip). ``max_video_frames`` still caps each window + as a context-budget safety net. """ if not record.frame_timestamps: return [] + if window is not None: + w0, w1 = float(window[0]), float(window[1]) + dur = max(0.0, w1 - w0) + n = max(1, int(round(dur * self.config.frames_per_second)) + 1) + n = min(n, self.config.max_video_frames) + if n <= 1 or dur <= 0.0: + timestamps = [0.5 * (w0 + w1)] + else: + step = dur / (n - 1) + timestamps = [w0 + i * step for i in range(n)] + return to_video_block(self.frame_provider.frames_at(record, timestamps)) if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider): cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips" clip = self.frame_provider.episode_clip_path(record, cache_dir) @@ -541,6 +568,17 @@ class PlanSubtasksMemoryModule: episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0] effective_task = task if task is not None else record.episode_task + # ---- Windowed path (constant temporal density) --------------- + # When ``subtask_window_seconds > 0`` and the episode is longer + # than one window, process the episode in fixed-length windows so + # the VLM always sees ``frames_per_second`` density (instead of a + # sparse 32-frame whole-episode view). Each window runs the full + # describe -> segment -> verify chain on its own frames; results + # are merged + stitched into a contiguous whole-episode cover. + window_s = float(getattr(self.config, "subtask_window_seconds", 0.0) or 0.0) + if window_s > 0.0 and episode_duration > window_s: + return self._generate_subtasks_windowed(record, effective_task, window_s) + # ---- Pass 1 (optional): grounding description ---------------- observation_block = "" if getattr(self.config, "subtask_describe_first", False): @@ -586,6 +624,91 @@ class PlanSubtasksMemoryModule: return cleaned + def _generate_subtasks_windowed( + self, record: EpisodeRecord, task: str, window_s: float + ) -> list[dict[str, Any]]: + """Subtask generation in fixed-length windows at constant fps. + + Splits ``[t0, t_last]`` into consecutive windows of ``window_s`` + seconds, runs the describe -> segment -> verify chain on each + window's own frames (sampled at ``frames_per_second``), offsets + each window's spans back to absolute episode time, then merges + + stitches into a contiguous whole-episode cover. + """ + t0 = float(record.frame_timestamps[0]) + t_last = float(record.frame_timestamps[-1]) + all_spans: list[dict[str, Any]] = [] + w0 = t0 + n_windows = 0 + while w0 < t_last - 1e-6: + w1 = min(w0 + window_s, t_last) + all_spans.extend(self._subtasks_for_window(record, task, w0, w1)) + n_windows += 1 + w0 = w1 + logger.info( + "episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans", + record.episode_index, + n_windows, + window_s, + len(all_spans), + ) + # Merge across windows: clamp to the absolute episode, sort, and + # frame-snap to distinct starts (handles any boundary collisions). + cleaned = self._clean_spans(all_spans, record) + if not cleaned: + return [] + return self._stitch_full_coverage(cleaned, record) + + def _subtasks_for_window( + self, record: EpisodeRecord, task: str, w0: float, w1: float + ) -> list[dict[str, Any]]: + """Run describe -> segment -> verify on one ``[w0, w1]`` window. + + The model works in window-RELATIVE time ``[0, L]`` (it perceives + the window as a clip starting at 0); spans are offset back to + absolute ``[w0, w1]`` before returning. + """ + window = (w0, w1) + win_len = max(0.0, w1 - w0) + + observation_block = "" + if getattr(self.config, "subtask_describe_first", False): + description = self._describe_episode(record, task, window=window) + if description: + observation_block = ( + "You watched this video clip and described, chronologically, " + "ONLY what the robot actually does:\n" + f'"""{description}"""\n\n' + "Segment THAT grounded description (cross-checked against " + "the clip) into atomic subtasks. Do not introduce any " + "action that is not in your description above.\n\n" + ) + + prompt = load_prompt("module_1_subtasks").format( + episode_task=task, + min_subtask_seconds=self.config.min_subtask_seconds, + max_steps=self.config.plan_max_steps, + episode_duration=f"{win_len:.3f}", + observation_block=observation_block, + ) + spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks") + # Window-relative clamp; no frame-snap dedupe yet (done on the + # merged absolute set). + cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False) + if not cleaned: + return [] + + if getattr(self.config, "subtask_verify", False): + cleaned = self._verify_subtasks(record, task, cleaned, window=window) + if not cleaned: + return [] + + # Offset window-relative spans back to absolute episode time. + for s in cleaned: + s["start"] = w0 + float(s["start"]) + s["end"] = w0 + float(s["end"]) + return cleaned + def _stitch_full_coverage( self, spans: list[dict[str, Any]], record: EpisodeRecord ) -> list[dict[str, Any]]: @@ -619,13 +742,28 @@ class PlanSubtasksMemoryModule: return spans def _clean_spans( - self, spans: Any, record: EpisodeRecord + self, + spans: Any, + record: EpisodeRecord, + bounds: tuple[float, float] | None = None, + dedupe: bool = True, ) -> list[dict[str, Any]]: - """Clamp / sort / dedupe raw VLM subtask spans into valid rows.""" + """Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows. + + ``bounds`` overrides the clamp range — pass the window's + ``(w_lo, w_hi)`` when cleaning window-relative spans, or leave + ``None`` to clamp to the whole episode ``[t0, t_last]``. + ``dedupe`` runs the frame-snap distinct-start step; skip it for + window-relative spans (frame snapping is done once on the merged, + absolute-time set). + """ if not spans: return [] - t0 = record.frame_timestamps[0] - t_last = record.frame_timestamps[-1] + if bounds is not None: + lo, hi = float(bounds[0]), float(bounds[1]) + else: + lo = record.frame_timestamps[0] + hi = record.frame_timestamps[-1] cleaned: list[dict[str, Any]] = [] for span in spans: try: @@ -634,20 +772,24 @@ class PlanSubtasksMemoryModule: text = str(span["text"]).strip() except (KeyError, ValueError, TypeError): continue - start = max(t0, min(start, t_last)) - end = max(t0, min(end, t_last)) + start = max(lo, min(start, hi)) + end = max(lo, min(end, hi)) if end < start: start, end = end, start if not text: continue cleaned.append({"text": text, "start": start, "end": end}) cleaned.sort(key=lambda s: s["start"]) - return self._dedupe_starts_to_distinct_frames(cleaned, record) + if dedupe: + return self._dedupe_starts_to_distinct_frames(cleaned, record) + return cleaned - def _describe_episode(self, record: EpisodeRecord, task: str) -> str: - """Grounding pass: free-form chronological description of the video.""" + def _describe_episode( + self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None + ) -> str: + """Grounding pass: free-form chronological description of the (windowed) video.""" prompt = load_prompt("module_1_subtask_describe").format(episode_task=task) - text = self._vlm_field(self._video_message(record, prompt), "description") + text = self._vlm_field(self._video_message(record, prompt, window=window), "description") return text.strip() if isinstance(text, str) and text.strip() else "" def _verify_subtasks( @@ -655,6 +797,7 @@ class PlanSubtasksMemoryModule: record: EpisodeRecord, task: str, spans: list[dict[str, Any]], + window: tuple[float, float] | None = None, ) -> list[dict[str, Any]]: """Adversarial pass: drop proposed subtasks not visible in the video. @@ -674,8 +817,16 @@ class PlanSubtasksMemoryModule: prompt = load_prompt("module_1_subtask_verify").format( episode_task=task, subtasks_json=subtasks_json ) - kept_raw = self._vlm_field(self._video_message(record, prompt), "subtasks") - kept = self._clean_spans(kept_raw, record) + kept_raw = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks") + # Windowed verify: the video is sampled from the absolute window + # ``[w0, w1]`` but the model perceives it as a clip starting at 0, + # so proposed + returned times are window-RELATIVE in ``[0, L]``. + # Clamp to that relative range and skip the absolute frame-snap + # dedupe (done once later on the merged absolute-time set). + clamp = (0.0, float(window[1] - window[0])) if window is not None else None + kept = self._clean_spans( + kept_raw, record, bounds=clamp, dedupe=window is None + ) if not kept: logger.info( "episode %d: verify pass returned nothing — keeping the %d "