diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py index ed226bd02..7d55a5d8d 100644 --- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -322,13 +322,32 @@ class PlanSubtasksMemoryModule: episode_duration=f"{episode_duration:.3f}", vocabulary_block=self._subtask_vocabulary_block(), ) - spans = self._vlm_field(self._video_message(record, prompt), "subtasks") + messages = self._video_message(record, prompt) + spans = self._vlm_field(messages, "subtasks") + # When a vocabulary is in force, do a single targeted retry if + # any returned subtask is off-vocab — strict exact-match only, + # no fuzzy snapping. The retry includes the offending strings + # and the full canonical list so the VLM can correct itself. + if self.vocabulary is not None and self.vocabulary.subtasks and spans: + invalid = self._invalid_subtasks(spans) + if invalid: + logger.info( + "episode %d: VLM emitted %d off-vocab subtask(s) (%s); retrying once", + record.episode_index, + len(invalid), + invalid, + ) + retry_msg = self._build_subtask_retry_message(messages, invalid) + retried = self._vlm_field(retry_msg, "subtasks") + if retried: + spans = retried + if not spans: return [] # clamp to [t0, t_last] and sort t0 = record.frame_timestamps[0] t_last = record.frame_timestamps[-1] - raw: list[dict[str, Any]] = [] + cleaned: list[dict[str, Any]] = [] for span in spans: try: start = float(span["start"]) @@ -340,46 +359,20 @@ class PlanSubtasksMemoryModule: end = max(t0, min(end, t_last)) if end < start: start, end = end, start - if text: - raw.append({"text": text, "start": start, "end": end}) - - # Without a vocabulary, free-form spans pass through unchanged. - if self.vocabulary is None or not self.vocabulary.subtasks: - raw.sort(key=lambda s: s["start"]) - return raw - - # With a vocabulary, snap each span to the closest canonical - # label. Two-pass: first try the normal Jaccard floor (drops - # off-topic hallucinations); if that leaves the episode with - # zero subtasks, fall back to snap-without-floor so the episode - # is never silently emptied — a wrong canonical label is still - # closer to the right phase than nothing at all. - cleaned: list[dict[str, Any]] = [] - for span in raw: - mapped = self._canonicalize_subtask(span["text"]) - if mapped: - cleaned.append({**span, "text": mapped}) - if not cleaned and raw: - logger.warning( - "episode %d: every VLM subtask was off-vocabulary " - "(%d spans); snapping to closest canonical label anyway " - "(check meta/canonical_vocabulary.json for missing phases)", - record.episode_index, - len(raw), - ) - for span in raw: - mapped = self._canonicalize_subtask(span["text"], force=True) - if mapped: - cleaned.append({**span, "text": mapped}) - elif len(cleaned) < len(raw): - logger.info( - "episode %d: %d/%d subtasks survived canonicalisation; " - "the rest were off-vocabulary", - record.episode_index, - len(cleaned), - len(raw), - ) + if not text: + continue + text = self._canonicalize_subtask(text) + if not text: + continue + cleaned.append({"text": text, "start": start, "end": end}) cleaned.sort(key=lambda s: s["start"]) + if self.vocabulary is not None and self.vocabulary.subtasks and not cleaned: + logger.warning( + "episode %d: every VLM subtask was off-vocab even after retry — " + "episode left empty (extend meta/canonical_vocabulary.json to " + "cover the missing phase)", + record.episode_index, + ) return cleaned # ------------------------------------------------------------------ @@ -420,70 +413,93 @@ class PlanSubtasksMemoryModule: f"{bullets}\n\n" ) - _CANONICALIZE_JACCARD_FLOOR: float = 0.25 - _CANONICALIZE_IGNORE_TOKENS: frozenset[str] = frozenset( - {"the", "a", "an", "to", "into", "from", "of", "on", "over", "at"} - ) + _NORMALIZE_STRIP_TOKENS: frozenset[str] = frozenset({"the", "a", "an"}) - def _canonicalize_subtask(self, text: str, *, force: bool = False) -> str: - """Snap ``text`` to the closest canonical subtask string. + def _canonicalize_subtask(self, text: str) -> str: + """Validate ``text`` against the canonical vocabulary; no fuzzy snap. Without a vocabulary, the original text passes through. With a - vocabulary, an exact case-insensitive match wins; failing that, - the best Jaccard overlap on the word set is used as a tolerant - fuzzy match (handles articles / minor reorderings). + vocabulary, accept the span only if its normalised form (lower- + cased, articles stripped, whitespace collapsed) matches a + canonical entry exactly — the canonical wording is returned so + the supervised string is byte-identical across episodes. - Behaviour at the Jaccard floor depends on ``force``: - - ``force=False`` (default): below ``_CANONICALIZE_JACCARD_FLOOR`` - the subtask is dropped. ``_generate_subtasks`` runs this first - to filter genuine off-topic hallucinations. - - ``force=True``: always snap, no floor. ``_generate_subtasks`` - uses this in a second pass when the first pass would otherwise - empty the episode — a slightly-wrong canonical label is still - closer to the right phase than no subtask at all, which makes - the whole episode invisible to the downstream policy. + Off-vocab spans are dropped (empty string). Upstream + ``_generate_subtasks`` triggers a targeted retry before reaching + the drop path; this function never snaps or warps a span into + a different label. """ if self.vocabulary is None or not self.vocabulary.subtasks: return text.strip() - candidates = self.vocabulary.subtasks - cleaned = text.strip() - lowered = cleaned.lower() - for candidate in candidates: - if candidate.lower() == lowered: + normalised = self._normalize(text) + if not normalised: + return "" + for candidate in self.vocabulary.subtasks: + if self._normalize(candidate) == normalised: return candidate - # Jaccard fallback: token-set overlap, drop articles + adverbs. - words = { - w for w in lowered.replace(",", " ").split() - if w and w not in self._CANONICALIZE_IGNORE_TOKENS - } - if not words: - return "" - best: tuple[float, str] | None = None - for candidate in candidates: - cand_words = { - w for w in candidate.lower().replace(",", " ").split() - if w and w not in self._CANONICALIZE_IGNORE_TOKENS - } - if not cand_words: + return "" + + @classmethod + def _normalize(cls, text: str) -> str: + """Lowercase, strip articles, collapse whitespace, drop punctuation.""" + words = [ + w.strip(".,:;\"'!?()") + for w in text.lower().replace(",", " ").split() + ] + return " ".join(w for w in words if w and w not in cls._NORMALIZE_STRIP_TOKENS) + + def _invalid_subtasks(self, spans: list[dict[str, Any]]) -> list[str]: + """Return the unique off-vocab subtask strings the VLM produced.""" + seen: list[str] = [] + for span in spans: + text = str((span or {}).get("text") or "").strip() + if not text: continue - inter = len(words & cand_words) - union = len(words | cand_words) - score = inter / union if union else 0.0 - if best is None or score > best[0]: - best = (score, candidate) - if best is None: - return "" - if not force and best[0] < self._CANONICALIZE_JACCARD_FLOOR: - logger.info( - "subtask %r dropped — best canonical match %r scored %.2f " - "(< %.2f Jaccard floor)", - cleaned, - best[1], - best[0], - self._CANONICALIZE_JACCARD_FLOOR, - ) - return "" - return best[1] + if self._canonicalize_subtask(text): + continue + if text not in seen: + seen.append(text) + return seen + + def _build_subtask_retry_message( + self, original_messages: list[dict[str, Any]], invalid: list[str] + ) -> list[dict[str, Any]]: + """Compose a one-shot correction prompt naming the off-vocab strings.""" + assert self.vocabulary is not None + canonical = "\n".join(f"- {s}" for s in self.vocabulary.subtasks) + invalid_list = "\n".join(f"- {s!r}" for s in invalid) + correction = ( + "Your previous response included subtask labels that are NOT in " + "the canonical vocabulary:\n" + f"{invalid_list}\n\n" + "Re-emit the same segmentation (same number of spans, same start/end " + "timestamps where they were valid) but replace every off-vocab " + "label with the EXACT canonical string for that phase, copied " + "verbatim from this list:\n" + f"{canonical}\n\n" + "Strict rules:\n" + "- Output strings must be byte-for-byte identical to entries above.\n" + "- No articles, no adverbs, no extra words.\n" + "- If a phase truly has no canonical match, omit that span entirely.\n" + "Return the same JSON shape as before." + ) + # Append the correction as an additional user turn; the model + # sees the original prompt + its prior output is implied by the + # conversation context (the VLM client is stateless, so we + # re-send the original content plus this correction). + retry_messages = [ + { + "role": m.get("role", "user"), + "content": ( + m.get("content") + if isinstance(m.get("content"), str) + else list(m.get("content") or []) + ), + } + for m in original_messages + ] + retry_messages.append({"role": "user", "content": correction}) + return retry_messages def _generate_plan( self, diff --git a/tests/annotations/test_vocabulary.py b/tests/annotations/test_vocabulary.py index 20d22a50d..1f1c046fe 100644 --- a/tests/annotations/test_vocabulary.py +++ b/tests/annotations/test_vocabulary.py @@ -182,59 +182,17 @@ def test_plan_module_inlines_vocab_into_subtask_prompt( assert any("grasp blue cube" in t for t in captured) -def test_plan_module_canonicalizes_paraphrased_subtask( +def test_plan_module_accepts_article_only_difference( fixture_dataset_root: Path, tmp_path: Path ) -> None: - """Off-vocab paraphrase with high token overlap snaps to canonical form.""" + """Articles like 'the'/'a'/'an' are stripped during validation.""" from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient def responder(_messages): return { "subtasks": [ - # paraphrase of "grasp blue cube" — overlapping tokens + # Same canonical phrase modulo "the" — should be accepted. {"text": "grasp the blue cube", "start": 0.0, "end": 0.4}, - # paraphrase of "place blue cube in box" — high overlap - {"text": "place the blue cube into the box", "start": 0.4, "end": 0.9}, - ] - } - - vlm = StubVlmClient(responder=responder) - vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY) - module = PlanSubtasksMemoryModule( - vlm=vlm, - config=PlanConfig(n_task_rephrasings=0), - vocabulary=vocab, - ) - record = next(iter_episodes(fixture_dataset_root)) - staging = EpisodeStaging(tmp_path / "stage", record.episode_index) - module.run_episode(record, staging) - rows = staging.read("plan") - subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] - # Both paraphrases snapped onto canonical strings. - assert subtask_texts == ["grasp blue cube", "place blue cube in box"] - - -def test_plan_module_drops_off_vocab_subtask( - fixture_dataset_root: Path, tmp_path: Path -) -> None: - """A subtask with low overlap to every canonical label is dropped. - - Drop only kicks in when *at least one* other subtask survives — if - every span would be dropped the episode would come out empty, so - ``_generate_subtasks`` falls back to snap-without-floor; that path - is exercised by ``test_plan_module_snaps_when_all_off_vocab``. - """ - from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient - - def responder(_messages): - return { - "subtasks": [ - # in-vocab — keeps the episode non-empty so the floor - # is allowed to drop the next span. - {"text": "grasp blue cube", "start": 0.0, "end": 0.4}, - # off-vocab hallucination — no token overlap above the - # Jaccard floor; should be dropped. - {"text": "perform a fancy macarena dance", "start": 0.4, "end": 0.9}, ] } @@ -253,18 +211,114 @@ def test_plan_module_drops_off_vocab_subtask( assert subtask_texts == ["grasp blue cube"] -def test_plan_module_snaps_when_all_off_vocab( +def test_plan_module_retries_when_subtask_off_vocab( fixture_dataset_root: Path, tmp_path: Path ) -> None: - """All-off-vocab spans snap to nearest canonical instead of emptying the episode.""" + """One-shot retry replaces an off-vocab paraphrase with the canonical form.""" + from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient + + call_count = {"n": 0} + + def responder(messages): + call_count["n"] += 1 + # First call: returns an off-vocab paraphrase. + if call_count["n"] == 1: + return { + "subtasks": [ + # paraphrase, not in vocab + {"text": "pick up blue cube", "start": 0.0, "end": 0.4}, + ] + } + # Second call (the retry): should contain the correction prompt; + # respond with the canonical phrase exactly. + last_user_text = "" + for message in messages: + content = message.get("content") + if isinstance(content, str): + last_user_text = content + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + last_user_text = block.get("text", "") + assert "NOT in the canonical vocabulary" in last_user_text + return { + "subtasks": [ + {"text": "grasp blue cube", "start": 0.0, "end": 0.4}, + ] + } + + vlm = StubVlmClient(responder=responder) + vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY) + module = PlanSubtasksMemoryModule( + vlm=vlm, + config=PlanConfig(n_task_rephrasings=0), + vocabulary=vocab, + ) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("plan") + subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] + assert subtask_texts == ["grasp blue cube"] + # The retry must have fired exactly once. + assert call_count["n"] == 2 + + +def test_plan_module_drops_off_vocab_subtask_after_retry( + fixture_dataset_root: Path, tmp_path: Path +) -> None: + """If the VLM stays off-vocab even after the retry, the bad span is dropped.""" + from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient + + call_count = {"n": 0} + + def responder(_messages): + call_count["n"] += 1 + # Both calls return the same off-vocab span — the model can't + # be corrected. The second call also returns one in-vocab span + # so the episode isn't empty; this lets us check that the + # off-vocab span is dropped without affecting the in-vocab one. + if call_count["n"] == 1: + return { + "subtasks": [ + {"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4}, + {"text": "grasp blue cube", "start": 0.4, "end": 0.9}, + ] + } + return { + "subtasks": [ + {"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4}, + {"text": "grasp blue cube", "start": 0.4, "end": 0.9}, + ] + } + + vlm = StubVlmClient(responder=responder) + vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY) + module = PlanSubtasksMemoryModule( + vlm=vlm, + config=PlanConfig(n_task_rephrasings=0), + vocabulary=vocab, + ) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("plan") + subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] + # Retry fired exactly once; bad span dropped, good span kept. + assert call_count["n"] == 2 + assert subtask_texts == ["grasp blue cube"] + + +def test_plan_module_empty_when_all_off_vocab_after_retry( + fixture_dataset_root: Path, tmp_path: Path +) -> None: + """All-off-vocab spans → episode comes out empty (no silent fuzzy snap).""" from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient def responder(_messages): + # Returns the same off-vocab spans on both attempts. return { "subtasks": [ - # Both off-vocab — would normally be dropped. The - # fallback should snap each to its best canonical match - # rather than leave the episode with no subtasks at all. {"text": "make a smoothie", "start": 0.0, "end": 0.4}, {"text": "consult the wizard", "start": 0.4, "end": 0.9}, ] @@ -282,12 +336,10 @@ def test_plan_module_snaps_when_all_off_vocab( module.run_episode(record, staging) rows = staging.read("plan") subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] - # Two off-vocab spans → two canonical subtasks (snapped to nearest - # by Jaccard with no floor). The exact canonical choice doesn't - # matter — only that the episode came out with subtasks rather - # than empty. - assert len(subtask_texts) == 2 - assert all(s in _CANONICAL_SUBTASKS for s in subtask_texts) + # No subtask gets fabricated — better to leave the episode empty + # so the operator notices the vocabulary gap than to silently + # warp the labels. + assert subtask_texts == [] def test_plan_module_without_vocab_passes_through(