refactor(annotate): consolidate Module 1's prompt → VLM → JSON-extract pattern
Five Module 1 sub-prompts (`_derive_task_from_video`,
`_generate_task_rephrasings`, `_generate_subtasks`, `_generate_plan`,
`_generate_memory`) all repeated the same shape:
    result = self.vlm.generate_json([messages])[0]
    if isinstance(result, dict) and isinstance(result.get(<field>), <type>):
        ...
…each spelled with slightly different field names + post-processing.
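In `_generate_plan`, for instance, the shape was spelled out as (see the removed lines in the diff below):

    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    result = self.vlm.generate_json([messages])[0]
    if isinstance(result, dict) and isinstance(result.get("plan"), str):
        return result["plan"].strip()
    return None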
Three small helpers replace it:
* `_vlm_field(messages, field)` — single VLM call, returns
``result[field]`` or ``None``. Centralizes the
``generate_json([m])[0]`` + ``isinstance(dict)`` dance.
* `_text_message(text)` — wraps a string in the canonical user-message
shape every text-only prompt builds inline.
* `_video_message(record, prompt)` — combines the episode video block
with a prompt; replaces the duplicated video-block construction
inside `_generate_subtasks` (which previously inlined the same
``use_video_url``/``frames_per_second``/``max_video_frames`` branches
that `_episode_video_block` already implements).
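Composed, the two call-site shapes now read (both visible in the diff below):

    # text-only prompt (_generate_memory):
    memory = self._vlm_field(self._text_message(prompt), "memory")
    return memory.strip() if isinstance(memory, str) else ""

    # video-grounded prompt (_generate_subtasks):
    spans = self._vlm_field(self._video_message(record, prompt), "subtasks")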
Net -35 LOC. Each call site is now 3-5 lines instead of 10-20. The
public method signatures are unchanged, so tests don't move.
Drive-by fixes: `_task_seems_bad` collapsed to a single return per ruff
SIM103; the `zip` in `run_plan_updates` now passes `strict=True` per ruff B905.
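`strict=True` turns a mismatched `interjection_times`/`interjection_texts`
pair from silent truncation into an immediate error:

    list(zip([0.5, 1.0], ["a"]))               # [(0.5, 'a')]  (1.0 silently dropped)
    list(zip([0.5, 1.0], ["a"], strict=True))  # ValueError: zip() argument 2 is shorter than argument 1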
Tests: same 2 pre-existing module-impl failures
(`test_module1_attaches_video_block_to_subtask_prompt`,
`test_module2_mid_episode_emits_paired_interjection_and_speech`) —
they were failing on `origin/feat/language-annotation-pipeline` before
this commit and continue to do so for the same reasons. 61/63 in the
language stack pass; pre-commit clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -19,9 +19,8 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from dataclasses import dataclass, field
-from typing import Any
-
 from pathlib import Path
+from typing import Any
 
 from ..config import Module1Config
 from ..frames import (
@@ -81,9 +80,7 @@ class PlanSubtasksMemoryModule:
         # so the policy sees diverse phrasings during training.
         t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
         if self.config.n_task_rephrasings > 0 and effective_task:
-            rephrasings = self._generate_task_rephrasings(
-                effective_task, n=self.config.n_task_rephrasings
-            )
+            rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
             # Always include the effective task itself as the first variant
             # so the rotation is guaranteed to cover the source-of-truth
             # phrasing, not just synthetic alternatives.
@@ -133,9 +130,7 @@ class PlanSubtasksMemoryModule:
         for i, span in enumerate(subtask_spans[1:], start=1):
             completed = subtask_spans[i - 1]["text"]
             remaining = [s["text"] for s in subtask_spans[i:]]
-            mem_text = self._generate_memory(
-                record, prior_memory, completed, remaining, task=effective_task
-            )
+            mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
             if mem_text:
                 ts = _snap_to_frame(span["start"], record.frame_timestamps)
                 rows.append(
@@ -193,44 +188,50 @@ class PlanSubtasksMemoryModule:
             return True
         if len(task.split()) < int(self.config.derive_task_min_words):
             return True
-        if task.lower() in self._PLACEHOLDER_TASKS:
-            return True
-        return False
+        return task.lower() in self._PLACEHOLDER_TASKS
 
+    # ------------------------------------------------------------------
+    # VLM call helpers (factored out: every Module-1 prompt below follows
+    # the same "build messages → single VLM call → pull a named field"
+    # shape, only differing in field name + post-processing).
+    # ------------------------------------------------------------------
+
+    def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
+        """Run a single VLM call and return ``result[field]`` or ``None``.
+
+        Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
+        dance every prompt-call site needs.
+        """
+        result = self.vlm.generate_json([messages])[0]
+        if isinstance(result, dict):
+            return result.get(field)
+        return None
+
+    @staticmethod
+    def _text_message(text: str) -> list[dict[str, Any]]:
+        """One-shot text-only user message wrapped for ``generate_json``."""
+        return [{"role": "user", "content": [{"type": "text", "text": text}]}]
+
+    def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
+        """User message combining the episode video block with ``prompt``."""
+        content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
+        return [{"role": "user", "content": content}]
+
     def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
         """Ask the VLM "what is this video about" with no task hint at all."""
-        prompt = load_prompt("module_1_video_task")
-        video_block = self._episode_video_block(record)
-        content = [*video_block, {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("task"), str):
-            text = result["task"].strip()
-            if text:
-                return text
-        return None
+        text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
+        return text.strip() if isinstance(text, str) and text.strip() else None
 
     def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
         """Generate ``n`` text-only paraphrases of ``base_task``."""
         if n <= 0 or not base_task:
             return []
-        prompt = load_prompt("module_1_task_rephrasings").format(
-            base_task=base_task, n=n
-        )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if not isinstance(result, dict):
-            return []
-        raw = result.get("rephrasings")
+        prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
+        raw = self._vlm_field(self._text_message(prompt), "rephrasings")
         if not isinstance(raw, list):
             return []
-        out: list[str] = []
-        for item in raw:
-            if isinstance(item, str):
-                cleaned = item.strip().strip('"').strip("'")
-                if cleaned:
-                    out.append(cleaned)
-        return out[:n]
+        out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
+        return [s for s in out if s][:n]
 
     def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
         """Same video block ``_generate_subtasks`` builds — extracted helper."""
@@ -245,9 +246,7 @@ class PlanSubtasksMemoryModule:
             else []
         )
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
-        target_count = max(
-            1, int(round(episode_duration * self.config.frames_per_second))
-        )
+        target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
         target_count = min(target_count, self.config.max_video_frames)
         video_frames = self.frame_provider.video_for_episode(record, target_count)
         return to_video_block(video_frames)
@@ -270,9 +269,7 @@ class PlanSubtasksMemoryModule:
         """
         existing = staging.read("module_1")
        spans = self._reconstruct_subtasks_from_rows(existing)
-        already_planned: set[float] = {
-            float(r["timestamp"]) for r in existing if r.get("style") == "plan"
-        }
+        already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
         new_rows = list(existing)
 
         texts: list[str | None] = (
@@ -280,14 +277,12 @@ class PlanSubtasksMemoryModule:
             if interjection_texts is None
             else [str(t) if t else None for t in interjection_texts]
         )
-        for raw_t, inter_text in zip(interjection_times, texts):
+        for raw_t, inter_text in zip(interjection_times, texts, strict=True):
             t = _snap_to_frame(raw_t, record.frame_timestamps)
             if t in already_planned:
                 continue
             already_planned.add(t)
-            plan_text = self._generate_plan(
-                record, spans, refresh_t=t, interjection=inter_text
-            )
+            plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
             if plan_text is not None:
                 new_rows.append(
                     {
@@ -315,9 +310,7 @@ class PlanSubtasksMemoryModule:
             last_t = t
         return out
 
-    def _generate_subtasks(
-        self, record: EpisodeRecord, *, task: str | None = None
-    ) -> list[dict[str, Any]]:
+    def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
        if record.row_count == 0 or not record.frame_timestamps:
             return []
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
@@ -327,26 +320,7 @@ class PlanSubtasksMemoryModule:
             max_steps=self.config.plan_max_steps,
             episode_duration=f"{episode_duration:.3f}",
         )
-        if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
-            cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
-            clip = episode_clip_path(record, self.frame_provider, cache_dir)
-            video_block = (
-                to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
-                if clip is not None
-                else []
-            )
-        else:
-            target_count = max(
-                1,
-                int(round(episode_duration * self.config.frames_per_second)),
-            )
-            target_count = min(target_count, self.config.max_video_frames)
-            video_frames = self.frame_provider.video_for_episode(record, target_count)
-            video_block = to_video_block(video_frames)
-        content = [*video_block, {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
-        spans = result.get("subtasks") if isinstance(result, dict) else None
+        spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
         if not spans:
             return []
         # clamp to [t0, t_last] and sort
@@ -411,15 +385,9 @@ class PlanSubtasksMemoryModule:
             # where in the episode the plan stands so the re-emission
             # is grounded. Should be rare — plan refreshes are
             # interjection-driven by design.
-            prompt += (
-                f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current "
-                f"subtask: {current_subtask!r}.)\n"
-            )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("plan"), str):
-            return result["plan"].strip()
-        return None
+            prompt += f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current subtask: {current_subtask!r}.)\n"
+        plan = self._vlm_field(self._text_message(prompt), "plan")
+        return plan.strip() if isinstance(plan, str) else None
 
     def _generate_memory(
         self,
@@ -436,8 +404,5 @@ class PlanSubtasksMemoryModule:
             completed_subtask=completed,
             remaining_subtasks=", ".join(remaining) if remaining else "(none)",
         )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        result = self.vlm.generate_json([messages])[0]
-        if isinstance(result, dict) and isinstance(result.get("memory"), str):
-            return result["memory"].strip()
-        return ""
+        memory = self._vlm_field(self._text_message(prompt), "memory")
+        return memory.strip() if isinstance(memory, str) else ""