From a15e16c0721aeee195ce5e5b6f5851479d1ce418 Mon Sep 17 00:00:00 2001 From: pepijn Date: Sat, 23 May 2026 09:57:27 +0000 Subject: [PATCH] fix(annotate): replace fuzzy subtask snapping with strict match + one-shot retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Jaccard-overlap snap was warping VLM output into wrong canonical labels — e.g. an off-vocab "consult the wizard" span would silently become "grasp blue cube" if that scored highest. Even with a higher floor the operator can't tell which subtasks were paraphrases vs genuine mislabels in the resulting dataset. Replace with strict exact-match validation + a single targeted retry: 1. Generate subtasks as before. 2. If any returned subtask's normalised form (lowercased, articles stripped, whitespace collapsed) isn't in the canonical vocab, fire one retry call naming the offending strings and re-sending the full canonical list. The retry prompt requires byte-identical output from the vocab. 3. After the retry, validate again. Spans still off-vocab are dropped — no fuzzy snapping ever produces a different canonical label than the VLM actually emitted. 4. If every span ends up off-vocab even after the retry, warn loudly so the operator extends ``meta/canonical_vocabulary.json`` to cover the missing phase. The episode is left with empty subtasks rather than silently fabricated ones — visibility > sweep-under- the-rug. Promote ``_NORMALIZE_STRIP_TOKENS`` to a class constant and split the normalisation helper out so the retry-validation and the final canonicalisation share one source of truth. Tests: - test_plan_module_accepts_article_only_difference: "grasp the blue cube" still maps to canonical "grasp blue cube" (article-tolerant). - test_plan_module_retries_when_subtask_off_vocab: paraphrase triggers the retry which the VLM corrects in pass 2. - test_plan_module_drops_off_vocab_subtask_after_retry: VLM that refuses to correct → bad span dropped, in-vocab span kept. - test_plan_module_empty_when_all_off_vocab_after_retry: every span off-vocab → episode left empty (no warping). All 13 vocabulary tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) Co-authored-by: Cursor --- .../modules/plan_subtasks_memory.py | 208 ++++++++++-------- tests/annotations/test_vocabulary.py | 164 +++++++++----- 2 files changed, 220 insertions(+), 152 deletions(-) diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py index ed226bd02..7d55a5d8d 100644 --- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -322,13 +322,32 @@ class PlanSubtasksMemoryModule: episode_duration=f"{episode_duration:.3f}", vocabulary_block=self._subtask_vocabulary_block(), ) - spans = self._vlm_field(self._video_message(record, prompt), "subtasks") + messages = self._video_message(record, prompt) + spans = self._vlm_field(messages, "subtasks") + # When a vocabulary is in force, do a single targeted retry if + # any returned subtask is off-vocab — strict exact-match only, + # no fuzzy snapping. The retry includes the offending strings + # and the full canonical list so the VLM can correct itself. + if self.vocabulary is not None and self.vocabulary.subtasks and spans: + invalid = self._invalid_subtasks(spans) + if invalid: + logger.info( + "episode %d: VLM emitted %d off-vocab subtask(s) (%s); retrying once", + record.episode_index, + len(invalid), + invalid, + ) + retry_msg = self._build_subtask_retry_message(messages, invalid) + retried = self._vlm_field(retry_msg, "subtasks") + if retried: + spans = retried + if not spans: return [] # clamp to [t0, t_last] and sort t0 = record.frame_timestamps[0] t_last = record.frame_timestamps[-1] - raw: list[dict[str, Any]] = [] + cleaned: list[dict[str, Any]] = [] for span in spans: try: start = float(span["start"]) @@ -340,46 +359,20 @@ class PlanSubtasksMemoryModule: end = max(t0, min(end, t_last)) if end < start: start, end = end, start - if text: - raw.append({"text": text, "start": start, "end": end}) - - # Without a vocabulary, free-form spans pass through unchanged. - if self.vocabulary is None or not self.vocabulary.subtasks: - raw.sort(key=lambda s: s["start"]) - return raw - - # With a vocabulary, snap each span to the closest canonical - # label. Two-pass: first try the normal Jaccard floor (drops - # off-topic hallucinations); if that leaves the episode with - # zero subtasks, fall back to snap-without-floor so the episode - # is never silently emptied — a wrong canonical label is still - # closer to the right phase than nothing at all. - cleaned: list[dict[str, Any]] = [] - for span in raw: - mapped = self._canonicalize_subtask(span["text"]) - if mapped: - cleaned.append({**span, "text": mapped}) - if not cleaned and raw: - logger.warning( - "episode %d: every VLM subtask was off-vocabulary " - "(%d spans); snapping to closest canonical label anyway " - "(check meta/canonical_vocabulary.json for missing phases)", - record.episode_index, - len(raw), - ) - for span in raw: - mapped = self._canonicalize_subtask(span["text"], force=True) - if mapped: - cleaned.append({**span, "text": mapped}) - elif len(cleaned) < len(raw): - logger.info( - "episode %d: %d/%d subtasks survived canonicalisation; " - "the rest were off-vocabulary", - record.episode_index, - len(cleaned), - len(raw), - ) + if not text: + continue + text = self._canonicalize_subtask(text) + if not text: + continue + cleaned.append({"text": text, "start": start, "end": end}) cleaned.sort(key=lambda s: s["start"]) + if self.vocabulary is not None and self.vocabulary.subtasks and not cleaned: + logger.warning( + "episode %d: every VLM subtask was off-vocab even after retry — " + "episode left empty (extend meta/canonical_vocabulary.json to " + "cover the missing phase)", + record.episode_index, + ) return cleaned # ------------------------------------------------------------------ @@ -420,70 +413,93 @@ class PlanSubtasksMemoryModule: f"{bullets}\n\n" ) - _CANONICALIZE_JACCARD_FLOOR: float = 0.25 - _CANONICALIZE_IGNORE_TOKENS: frozenset[str] = frozenset( - {"the", "a", "an", "to", "into", "from", "of", "on", "over", "at"} - ) + _NORMALIZE_STRIP_TOKENS: frozenset[str] = frozenset({"the", "a", "an"}) - def _canonicalize_subtask(self, text: str, *, force: bool = False) -> str: - """Snap ``text`` to the closest canonical subtask string. + def _canonicalize_subtask(self, text: str) -> str: + """Validate ``text`` against the canonical vocabulary; no fuzzy snap. Without a vocabulary, the original text passes through. With a - vocabulary, an exact case-insensitive match wins; failing that, - the best Jaccard overlap on the word set is used as a tolerant - fuzzy match (handles articles / minor reorderings). + vocabulary, accept the span only if its normalised form (lower- + cased, articles stripped, whitespace collapsed) matches a + canonical entry exactly — the canonical wording is returned so + the supervised string is byte-identical across episodes. - Behaviour at the Jaccard floor depends on ``force``: - - ``force=False`` (default): below ``_CANONICALIZE_JACCARD_FLOOR`` - the subtask is dropped. ``_generate_subtasks`` runs this first - to filter genuine off-topic hallucinations. - - ``force=True``: always snap, no floor. ``_generate_subtasks`` - uses this in a second pass when the first pass would otherwise - empty the episode — a slightly-wrong canonical label is still - closer to the right phase than no subtask at all, which makes - the whole episode invisible to the downstream policy. + Off-vocab spans are dropped (empty string). Upstream + ``_generate_subtasks`` triggers a targeted retry before reaching + the drop path; this function never snaps or warps a span into + a different label. """ if self.vocabulary is None or not self.vocabulary.subtasks: return text.strip() - candidates = self.vocabulary.subtasks - cleaned = text.strip() - lowered = cleaned.lower() - for candidate in candidates: - if candidate.lower() == lowered: + normalised = self._normalize(text) + if not normalised: + return "" + for candidate in self.vocabulary.subtasks: + if self._normalize(candidate) == normalised: return candidate - # Jaccard fallback: token-set overlap, drop articles + adverbs. - words = { - w for w in lowered.replace(",", " ").split() - if w and w not in self._CANONICALIZE_IGNORE_TOKENS - } - if not words: - return "" - best: tuple[float, str] | None = None - for candidate in candidates: - cand_words = { - w for w in candidate.lower().replace(",", " ").split() - if w and w not in self._CANONICALIZE_IGNORE_TOKENS - } - if not cand_words: + return "" + + @classmethod + def _normalize(cls, text: str) -> str: + """Lowercase, strip articles, collapse whitespace, drop punctuation.""" + words = [ + w.strip(".,:;\"'!?()") + for w in text.lower().replace(",", " ").split() + ] + return " ".join(w for w in words if w and w not in cls._NORMALIZE_STRIP_TOKENS) + + def _invalid_subtasks(self, spans: list[dict[str, Any]]) -> list[str]: + """Return the unique off-vocab subtask strings the VLM produced.""" + seen: list[str] = [] + for span in spans: + text = str((span or {}).get("text") or "").strip() + if not text: continue - inter = len(words & cand_words) - union = len(words | cand_words) - score = inter / union if union else 0.0 - if best is None or score > best[0]: - best = (score, candidate) - if best is None: - return "" - if not force and best[0] < self._CANONICALIZE_JACCARD_FLOOR: - logger.info( - "subtask %r dropped — best canonical match %r scored %.2f " - "(< %.2f Jaccard floor)", - cleaned, - best[1], - best[0], - self._CANONICALIZE_JACCARD_FLOOR, - ) - return "" - return best[1] + if self._canonicalize_subtask(text): + continue + if text not in seen: + seen.append(text) + return seen + + def _build_subtask_retry_message( + self, original_messages: list[dict[str, Any]], invalid: list[str] + ) -> list[dict[str, Any]]: + """Compose a one-shot correction prompt naming the off-vocab strings.""" + assert self.vocabulary is not None + canonical = "\n".join(f"- {s}" for s in self.vocabulary.subtasks) + invalid_list = "\n".join(f"- {s!r}" for s in invalid) + correction = ( + "Your previous response included subtask labels that are NOT in " + "the canonical vocabulary:\n" + f"{invalid_list}\n\n" + "Re-emit the same segmentation (same number of spans, same start/end " + "timestamps where they were valid) but replace every off-vocab " + "label with the EXACT canonical string for that phase, copied " + "verbatim from this list:\n" + f"{canonical}\n\n" + "Strict rules:\n" + "- Output strings must be byte-for-byte identical to entries above.\n" + "- No articles, no adverbs, no extra words.\n" + "- If a phase truly has no canonical match, omit that span entirely.\n" + "Return the same JSON shape as before." + ) + # Append the correction as an additional user turn; the model + # sees the original prompt + its prior output is implied by the + # conversation context (the VLM client is stateless, so we + # re-send the original content plus this correction). + retry_messages = [ + { + "role": m.get("role", "user"), + "content": ( + m.get("content") + if isinstance(m.get("content"), str) + else list(m.get("content") or []) + ), + } + for m in original_messages + ] + retry_messages.append({"role": "user", "content": correction}) + return retry_messages def _generate_plan( self, diff --git a/tests/annotations/test_vocabulary.py b/tests/annotations/test_vocabulary.py index 20d22a50d..1f1c046fe 100644 --- a/tests/annotations/test_vocabulary.py +++ b/tests/annotations/test_vocabulary.py @@ -182,59 +182,17 @@ def test_plan_module_inlines_vocab_into_subtask_prompt( assert any("grasp blue cube" in t for t in captured) -def test_plan_module_canonicalizes_paraphrased_subtask( +def test_plan_module_accepts_article_only_difference( fixture_dataset_root: Path, tmp_path: Path ) -> None: - """Off-vocab paraphrase with high token overlap snaps to canonical form.""" + """Articles like 'the'/'a'/'an' are stripped during validation.""" from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient def responder(_messages): return { "subtasks": [ - # paraphrase of "grasp blue cube" — overlapping tokens + # Same canonical phrase modulo "the" — should be accepted. {"text": "grasp the blue cube", "start": 0.0, "end": 0.4}, - # paraphrase of "place blue cube in box" — high overlap - {"text": "place the blue cube into the box", "start": 0.4, "end": 0.9}, - ] - } - - vlm = StubVlmClient(responder=responder) - vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY) - module = PlanSubtasksMemoryModule( - vlm=vlm, - config=PlanConfig(n_task_rephrasings=0), - vocabulary=vocab, - ) - record = next(iter_episodes(fixture_dataset_root)) - staging = EpisodeStaging(tmp_path / "stage", record.episode_index) - module.run_episode(record, staging) - rows = staging.read("plan") - subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] - # Both paraphrases snapped onto canonical strings. - assert subtask_texts == ["grasp blue cube", "place blue cube in box"] - - -def test_plan_module_drops_off_vocab_subtask( - fixture_dataset_root: Path, tmp_path: Path -) -> None: - """A subtask with low overlap to every canonical label is dropped. - - Drop only kicks in when *at least one* other subtask survives — if - every span would be dropped the episode would come out empty, so - ``_generate_subtasks`` falls back to snap-without-floor; that path - is exercised by ``test_plan_module_snaps_when_all_off_vocab``. - """ - from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient - - def responder(_messages): - return { - "subtasks": [ - # in-vocab — keeps the episode non-empty so the floor - # is allowed to drop the next span. - {"text": "grasp blue cube", "start": 0.0, "end": 0.4}, - # off-vocab hallucination — no token overlap above the - # Jaccard floor; should be dropped. - {"text": "perform a fancy macarena dance", "start": 0.4, "end": 0.9}, ] } @@ -253,18 +211,114 @@ def test_plan_module_drops_off_vocab_subtask( assert subtask_texts == ["grasp blue cube"] -def test_plan_module_snaps_when_all_off_vocab( +def test_plan_module_retries_when_subtask_off_vocab( fixture_dataset_root: Path, tmp_path: Path ) -> None: - """All-off-vocab spans snap to nearest canonical instead of emptying the episode.""" + """One-shot retry replaces an off-vocab paraphrase with the canonical form.""" + from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient + + call_count = {"n": 0} + + def responder(messages): + call_count["n"] += 1 + # First call: returns an off-vocab paraphrase. + if call_count["n"] == 1: + return { + "subtasks": [ + # paraphrase, not in vocab + {"text": "pick up blue cube", "start": 0.0, "end": 0.4}, + ] + } + # Second call (the retry): should contain the correction prompt; + # respond with the canonical phrase exactly. + last_user_text = "" + for message in messages: + content = message.get("content") + if isinstance(content, str): + last_user_text = content + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + last_user_text = block.get("text", "") + assert "NOT in the canonical vocabulary" in last_user_text + return { + "subtasks": [ + {"text": "grasp blue cube", "start": 0.0, "end": 0.4}, + ] + } + + vlm = StubVlmClient(responder=responder) + vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY) + module = PlanSubtasksMemoryModule( + vlm=vlm, + config=PlanConfig(n_task_rephrasings=0), + vocabulary=vocab, + ) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("plan") + subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] + assert subtask_texts == ["grasp blue cube"] + # The retry must have fired exactly once. + assert call_count["n"] == 2 + + +def test_plan_module_drops_off_vocab_subtask_after_retry( + fixture_dataset_root: Path, tmp_path: Path +) -> None: + """If the VLM stays off-vocab even after the retry, the bad span is dropped.""" + from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient + + call_count = {"n": 0} + + def responder(_messages): + call_count["n"] += 1 + # Both calls return the same off-vocab span — the model can't + # be corrected. The second call also returns one in-vocab span + # so the episode isn't empty; this lets us check that the + # off-vocab span is dropped without affecting the in-vocab one. + if call_count["n"] == 1: + return { + "subtasks": [ + {"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4}, + {"text": "grasp blue cube", "start": 0.4, "end": 0.9}, + ] + } + return { + "subtasks": [ + {"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4}, + {"text": "grasp blue cube", "start": 0.4, "end": 0.9}, + ] + } + + vlm = StubVlmClient(responder=responder) + vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY) + module = PlanSubtasksMemoryModule( + vlm=vlm, + config=PlanConfig(n_task_rephrasings=0), + vocabulary=vocab, + ) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("plan") + subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] + # Retry fired exactly once; bad span dropped, good span kept. + assert call_count["n"] == 2 + assert subtask_texts == ["grasp blue cube"] + + +def test_plan_module_empty_when_all_off_vocab_after_retry( + fixture_dataset_root: Path, tmp_path: Path +) -> None: + """All-off-vocab spans → episode comes out empty (no silent fuzzy snap).""" from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient def responder(_messages): + # Returns the same off-vocab spans on both attempts. return { "subtasks": [ - # Both off-vocab — would normally be dropped. The - # fallback should snap each to its best canonical match - # rather than leave the episode with no subtasks at all. {"text": "make a smoothie", "start": 0.0, "end": 0.4}, {"text": "consult the wizard", "start": 0.4, "end": 0.9}, ] @@ -282,12 +336,10 @@ def test_plan_module_snaps_when_all_off_vocab( module.run_episode(record, staging) rows = staging.read("plan") subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] - # Two off-vocab spans → two canonical subtasks (snapped to nearest - # by Jaccard with no floor). The exact canonical choice doesn't - # matter — only that the episode came out with subtasks rather - # than empty. - assert len(subtask_texts) == 2 - assert all(s in _CANONICAL_SUBTASKS for s in subtask_texts) + # No subtask gets fabricated — better to leave the episode empty + # so the operator notices the vocabulary gap than to silently + # warp the labels. + assert subtask_texts == [] def test_plan_module_without_vocab_passes_through(