diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
index c48d888fb..cb9290e5a 100644
--- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
+++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
@@ -19,9 +19,8 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from dataclasses import dataclass, field
-from typing import Any
-
 from pathlib import Path
+from typing import Any
 
 from ..config import Module1Config
 from ..frames import (
@@ -81,9 +80,7 @@ class PlanSubtasksMemoryModule:
         # so the policy sees diverse phrasings during training.
         t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
         if self.config.n_task_rephrasings > 0 and effective_task:
-            rephrasings = self._generate_task_rephrasings(
-                effective_task, n=self.config.n_task_rephrasings
-            )
+            rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
             # Always include the effective task itself as the first variant
             # so the rotation is guaranteed to cover the source-of-truth
             # phrasing, not just synthetic alternatives.
@@ -133,9 +130,7 @@ class PlanSubtasksMemoryModule:
         for i, span in enumerate(subtask_spans[1:], start=1):
             completed = subtask_spans[i - 1]["text"]
             remaining = [s["text"] for s in subtask_spans[i:]]
-            mem_text = self._generate_memory(
-                record, prior_memory, completed, remaining, task=effective_task
-            )
+            mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
             if mem_text:
                 ts = _snap_to_frame(span["start"], record.frame_timestamps)
                 rows.append(
@@ -193,44 +188,50 @@ class PlanSubtasksMemoryModule:
             return True
         if len(task.split()) < int(self.config.derive_task_min_words):
             return True
-        if task.lower() in self._PLACEHOLDER_TASKS:
-            return True
-        return False
+        return task.lower() in self._PLACEHOLDER_TASKS
+
+    # ------------------------------------------------------------------
+    # VLM call helpers (factored out: every Module-1 prompt below follows
+    # the same "build messages → single VLM call → pull a named field"
+    # shape, only differing in field name + post-processing).
+    # ------------------------------------------------------------------
+
+    def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
+        """Run a single VLM call and return ``result[field]`` or ``None``.
+
+        Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
+        dance every prompt-call site needs.
+        """
+        result = self.vlm.generate_json([messages])[0]
+        if isinstance(result, dict):
+            return result.get(field)
+        return None
+
+    @staticmethod
+    def _text_message(text: str) -> list[dict[str, Any]]:
+        """One-shot text-only user message wrapped for ``generate_json``."""
+        return [{"role": "user", "content": [{"type": "text", "text": text}]}]
+
+    def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
+        """User message combining the episode video block with ``prompt``."""
+        content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
+        return [{"role": "user", "content": content}]
 
     def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
         """Ask the VLM "what is this video about" with no task hint at all."""
-        prompt = load_prompt("module_1_video_task")
-        video_block = self._episode_video_block(record)
-        content = [*video_block, {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("task"), str):
-            text = result["task"].strip()
-            if text:
-                return text
-        return None
+        text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
+        return text.strip() if isinstance(text, str) and text.strip() else None
 
     def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
         """Generate ``n`` text-only paraphrases of ``base_task``."""
         if n <= 0 or not base_task:
             return []
-        prompt = load_prompt("module_1_task_rephrasings").format(
-            base_task=base_task, n=n
-        )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if not isinstance(result, dict):
-            return []
-        raw = result.get("rephrasings")
+        prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
+        raw = self._vlm_field(self._text_message(prompt), "rephrasings")
         if not isinstance(raw, list):
             return []
-        out: list[str] = []
-        for item in raw:
-            if isinstance(item, str):
-                cleaned = item.strip().strip('"').strip("'")
-                if cleaned:
-                    out.append(cleaned)
-        return out[:n]
+        out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
+        return [s for s in out if s][:n]
 
     def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
         """Same video block ``_generate_subtasks`` builds — extracted helper."""
@@ -245,9 +246,7 @@ class PlanSubtasksMemoryModule:
             else []
         )
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
-        target_count = max(
-            1, int(round(episode_duration * self.config.frames_per_second))
-        )
+        target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
         target_count = min(target_count, self.config.max_video_frames)
         video_frames = self.frame_provider.video_for_episode(record, target_count)
         return to_video_block(video_frames)
@@ -270,9 +269,7 @@ class PlanSubtasksMemoryModule:
         """
         existing = staging.read("module_1")
         spans = self._reconstruct_subtasks_from_rows(existing)
-        already_planned: set[float] = {
-            float(r["timestamp"]) for r in existing if r.get("style") == "plan"
-        }
+        already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
         new_rows = list(existing)
 
         texts: list[str | None] = (
@@ -280,14 +277,12 @@
             if interjection_texts is None
             else [str(t) if t else None for t in interjection_texts]
         )
-        for raw_t, inter_text in zip(interjection_times, texts):
+        for raw_t, inter_text in zip(interjection_times, texts, strict=True):
             t = _snap_to_frame(raw_t, record.frame_timestamps)
             if t in already_planned:
                 continue
             already_planned.add(t)
-            plan_text = self._generate_plan(
-                record, spans, refresh_t=t, interjection=inter_text
-            )
+            plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
             if plan_text is not None:
                 new_rows.append(
                     {
@@ -315,9 +310,7 @@ class PlanSubtasksMemoryModule:
             last_t = t
         return out
 
-    def _generate_subtasks(
-        self, record: EpisodeRecord, *, task: str | None = None
-    ) -> list[dict[str, Any]]:
+    def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
         if record.row_count == 0 or not record.frame_timestamps:
             return []
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
@@ -327,26 +320,7 @@ class PlanSubtasksMemoryModule:
             max_steps=self.config.plan_max_steps,
             episode_duration=f"{episode_duration:.3f}",
         )
-        if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
-            cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
-            clip = episode_clip_path(record, self.frame_provider, cache_dir)
-            video_block = (
-                to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
-                if clip is not None
-                else []
-            )
-        else:
-            target_count = max(
-                1,
-                int(round(episode_duration * self.config.frames_per_second)),
-            )
-            target_count = min(target_count, self.config.max_video_frames)
-            video_frames = self.frame_provider.video_for_episode(record, target_count)
-            video_block = to_video_block(video_frames)
-        content = [*video_block, {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
-        spans = result.get("subtasks") if isinstance(result, dict) else None
+        spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
         if not spans:
             return []
         # clamp to [t0, t_last] and sort
@@ -411,15 +385,9 @@ class PlanSubtasksMemoryModule:
             # where in the episode the plan stands so the re-emission
             # is grounded. Should be rare — plan refreshes are
             # interjection-driven by design.
-            prompt += (
-                f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current "
-                f"subtask: {current_subtask!r}.)\n"
-            )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("plan"), str):
-            return result["plan"].strip()
-        return None
+            prompt += f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current subtask: {current_subtask!r}.)\n"
+        plan = self._vlm_field(self._text_message(prompt), "plan")
+        return plan.strip() if isinstance(plan, str) else None
 
     def _generate_memory(
         self,
@@ -436,8 +404,5 @@ class PlanSubtasksMemoryModule:
             completed_subtask=completed,
             remaining_subtasks=", ".join(remaining) if remaining else "(none)",
         )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("memory"), str):
-            return result["memory"].strip()
-        return ""
+        memory = self._vlm_field(self._text_message(prompt), "memory")
+        return memory.strip() if isinstance(memory, str) else ""
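Usage sketch for the new helpers. The three additions encode a single contract: ``generate_json`` takes a batch of message lists and returns one parsed object per list, and a missing or non-dict result degrades to ``None`` rather than raising. The snippet below mirrors that contract outside the class; ``FakeVLM`` is a hypothetical stand-in for the real client, which this patch does not show.

# Sketch only (not part of the patch). FakeVLM is hypothetical; it
# mirrors the generate_json() contract the call sites above assume.
from typing import Any


def text_message(text: str) -> list[dict[str, Any]]:
    # Same message shape as PlanSubtasksMemoryModule._text_message.
    return [{"role": "user", "content": [{"type": "text", "text": text}]}]


class FakeVLM:
    """Hypothetical stand-in: echoes one canned payload per message list."""

    def __init__(self, payload: Any) -> None:
        self._payload = payload

    def generate_json(self, batches: list[list[dict[str, Any]]]) -> list[Any]:
        # One parsed JSON object per message list in the batch.
        return [self._payload for _ in batches]


def vlm_field(vlm: FakeVLM, messages: list[dict[str, Any]], field: str) -> Any:
    # Same guard as the new _vlm_field: single call, dict-checked access,
    # None on anything unexpected.
    result = vlm.generate_json([messages])[0]
    return result.get(field) if isinstance(result, dict) else None


plan = vlm_field(FakeVLM({"plan": " open the drawer "}), text_message("plan?"), "plan")
assert plan == " open the drawer "  # call sites strip() and type-check afterwards
assert vlm_field(FakeVLM(["not", "a", "dict"]), text_message("plan?"), "plan") is None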
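The ``strict=True`` added to the ``zip`` in the plan-refresh loop is behavioral, not cosmetic: on Python 3.10+ a length mismatch between ``interjection_times`` and ``texts`` now raises instead of silently dropping trailing interjections. The ``[None] * len(interjection_times)`` branch guarantees equal lengths, so the guard protects against caller-supplied ``interjection_texts``. An illustration with made-up values:

# Illustration of the strict=True change (values are made up).
times = [1.0, 2.5, 4.0]
texts = ["move left", "grasp", None]

# Equal lengths: strict zip behaves exactly like plain zip.
assert len(list(zip(times, texts, strict=True))) == 3

# Mismatched lengths: plain zip silently truncates ...
assert len(list(zip(times, texts[:2]))) == 2
# ... while strict=True surfaces the upstream bug instead.
try:
    list(zip(times, texts[:2], strict=True))
except ValueError:
    print("length mismatch raised, as intended")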
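For the frame-sampling math in ``_episode_video_block`` (only reflowed by this patch), a worked example with hypothetical config values shows the duration-scaled target count and the ``max_video_frames`` clamp:

# Hypothetical config values; the formula matches _episode_video_block.
episode_duration = 12.4   # seconds: frame_timestamps[-1] - frame_timestamps[0]
frames_per_second = 2.0   # config.frames_per_second
max_video_frames = 16     # config.max_video_frames

target_count = max(1, int(round(episode_duration * frames_per_second)))
assert target_count == 25            # round(24.8) = 25, floored at 1
target_count = min(target_count, max_video_frames)
assert target_count == 16            # clamped to the config ceiling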