fix(annotate): replace fuzzy subtask snapping with strict match + one-shot retry

The Jaccard-overlap snap was warping VLM output into wrong canonical
labels — e.g. an off-vocab "consult the wizard" span would silently
become "grasp blue cube" if that scored highest. Even with a higher
floor the operator can't tell which subtasks were paraphrases vs
genuine mislabels in the resulting dataset.

Replace with strict exact-match validation + a single targeted retry:

  1. Generate subtasks as before.
  2. If any returned subtask's normalised form (lowercased, articles
     stripped, whitespace collapsed) isn't in the canonical vocab,
     fire one retry call naming the offending strings and re-sending
     the full canonical list. The retry prompt requires byte-identical
     output from the vocab.
  3. After the retry, validate again. Spans still off-vocab are
     dropped — no fuzzy snapping ever produces a different canonical
     label than the VLM actually emitted.
  4. If every span ends up off-vocab even after the retry, warn loudly
     so the operator extends ``meta/canonical_vocabulary.json`` to
     cover the missing phase. The episode is left with empty subtasks
     rather than silently fabricated ones — visibility > sweep-under-
     the-rug.

Promote ``_NORMALIZE_STRIP_TOKENS`` to a class constant and split the
normalisation helper out so the retry-validation and the final
canonicalisation share one source of truth.

Tests:
  - test_plan_module_accepts_article_only_difference: "grasp the blue
    cube" still maps to canonical "grasp blue cube" (article-tolerant).
  - test_plan_module_retries_when_subtask_off_vocab: paraphrase
    triggers the retry which the VLM corrects in pass 2.
  - test_plan_module_drops_off_vocab_subtask_after_retry: VLM that
    refuses to correct → bad span dropped, in-vocab span kept.
  - test_plan_module_empty_when_all_off_vocab_after_retry: every
    span off-vocab → episode left empty (no warping).
All 13 vocabulary tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-23 09:57:27 +00:00
parent 336af85c09
commit a15e16c072
2 changed files with 220 additions and 152 deletions
@@ -322,13 +322,32 @@ class PlanSubtasksMemoryModule:
episode_duration=f"{episode_duration:.3f}", episode_duration=f"{episode_duration:.3f}",
vocabulary_block=self._subtask_vocabulary_block(), vocabulary_block=self._subtask_vocabulary_block(),
) )
spans = self._vlm_field(self._video_message(record, prompt), "subtasks") messages = self._video_message(record, prompt)
spans = self._vlm_field(messages, "subtasks")
# When a vocabulary is in force, do a single targeted retry if
# any returned subtask is off-vocab — strict exact-match only,
# no fuzzy snapping. The retry includes the offending strings
# and the full canonical list so the VLM can correct itself.
if self.vocabulary is not None and self.vocabulary.subtasks and spans:
invalid = self._invalid_subtasks(spans)
if invalid:
logger.info(
"episode %d: VLM emitted %d off-vocab subtask(s) (%s); retrying once",
record.episode_index,
len(invalid),
invalid,
)
retry_msg = self._build_subtask_retry_message(messages, invalid)
retried = self._vlm_field(retry_msg, "subtasks")
if retried:
spans = retried
if not spans: if not spans:
return [] return []
# clamp to [t0, t_last] and sort # clamp to [t0, t_last] and sort
t0 = record.frame_timestamps[0] t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1] t_last = record.frame_timestamps[-1]
raw: list[dict[str, Any]] = [] cleaned: list[dict[str, Any]] = []
for span in spans: for span in spans:
try: try:
start = float(span["start"]) start = float(span["start"])
@@ -340,46 +359,20 @@ class PlanSubtasksMemoryModule:
end = max(t0, min(end, t_last)) end = max(t0, min(end, t_last))
if end < start: if end < start:
start, end = end, start start, end = end, start
if text: if not text:
raw.append({"text": text, "start": start, "end": end}) continue
text = self._canonicalize_subtask(text)
# Without a vocabulary, free-form spans pass through unchanged. if not text:
if self.vocabulary is None or not self.vocabulary.subtasks: continue
raw.sort(key=lambda s: s["start"]) cleaned.append({"text": text, "start": start, "end": end})
return raw
# With a vocabulary, snap each span to the closest canonical
# label. Two-pass: first try the normal Jaccard floor (drops
# off-topic hallucinations); if that leaves the episode with
# zero subtasks, fall back to snap-without-floor so the episode
# is never silently emptied — a wrong canonical label is still
# closer to the right phase than nothing at all.
cleaned: list[dict[str, Any]] = []
for span in raw:
mapped = self._canonicalize_subtask(span["text"])
if mapped:
cleaned.append({**span, "text": mapped})
if not cleaned and raw:
logger.warning(
"episode %d: every VLM subtask was off-vocabulary "
"(%d spans); snapping to closest canonical label anyway "
"(check meta/canonical_vocabulary.json for missing phases)",
record.episode_index,
len(raw),
)
for span in raw:
mapped = self._canonicalize_subtask(span["text"], force=True)
if mapped:
cleaned.append({**span, "text": mapped})
elif len(cleaned) < len(raw):
logger.info(
"episode %d: %d/%d subtasks survived canonicalisation; "
"the rest were off-vocabulary",
record.episode_index,
len(cleaned),
len(raw),
)
cleaned.sort(key=lambda s: s["start"]) cleaned.sort(key=lambda s: s["start"])
if self.vocabulary is not None and self.vocabulary.subtasks and not cleaned:
logger.warning(
"episode %d: every VLM subtask was off-vocab even after retry — "
"episode left empty (extend meta/canonical_vocabulary.json to "
"cover the missing phase)",
record.episode_index,
)
return cleaned return cleaned
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -420,70 +413,93 @@ class PlanSubtasksMemoryModule:
f"{bullets}\n\n" f"{bullets}\n\n"
) )
_CANONICALIZE_JACCARD_FLOOR: float = 0.25 _NORMALIZE_STRIP_TOKENS: frozenset[str] = frozenset({"the", "a", "an"})
_CANONICALIZE_IGNORE_TOKENS: frozenset[str] = frozenset(
{"the", "a", "an", "to", "into", "from", "of", "on", "over", "at"}
)
def _canonicalize_subtask(self, text: str, *, force: bool = False) -> str: def _canonicalize_subtask(self, text: str) -> str:
"""Snap ``text`` to the closest canonical subtask string. """Validate ``text`` against the canonical vocabulary; no fuzzy snap.
Without a vocabulary, the original text passes through. With a Without a vocabulary, the original text passes through. With a
vocabulary, an exact case-insensitive match wins; failing that, vocabulary, accept the span only if its normalised form (lower-
the best Jaccard overlap on the word set is used as a tolerant cased, articles stripped, whitespace collapsed) matches a
fuzzy match (handles articles / minor reorderings). canonical entry exactly — the canonical wording is returned so
the supervised string is byte-identical across episodes.
Behaviour at the Jaccard floor depends on ``force``: Off-vocab spans are dropped (empty string). Upstream
- ``force=False`` (default): below ``_CANONICALIZE_JACCARD_FLOOR`` ``_generate_subtasks`` triggers a targeted retry before reaching
the subtask is dropped. ``_generate_subtasks`` runs this first the drop path; this function never snaps or warps a span into
to filter genuine off-topic hallucinations. a different label.
- ``force=True``: always snap, no floor. ``_generate_subtasks``
uses this in a second pass when the first pass would otherwise
empty the episode — a slightly-wrong canonical label is still
closer to the right phase than no subtask at all, which makes
the whole episode invisible to the downstream policy.
""" """
if self.vocabulary is None or not self.vocabulary.subtasks: if self.vocabulary is None or not self.vocabulary.subtasks:
return text.strip() return text.strip()
candidates = self.vocabulary.subtasks normalised = self._normalize(text)
cleaned = text.strip() if not normalised:
lowered = cleaned.lower() return ""
for candidate in candidates: for candidate in self.vocabulary.subtasks:
if candidate.lower() == lowered: if self._normalize(candidate) == normalised:
return candidate return candidate
# Jaccard fallback: token-set overlap, drop articles + adverbs.
words = {
w for w in lowered.replace(",", " ").split()
if w and w not in self._CANONICALIZE_IGNORE_TOKENS
}
if not words:
return "" return ""
best: tuple[float, str] | None = None
for candidate in candidates: @classmethod
cand_words = { def _normalize(cls, text: str) -> str:
w for w in candidate.lower().replace(",", " ").split() """Lowercase, strip articles, collapse whitespace, drop punctuation."""
if w and w not in self._CANONICALIZE_IGNORE_TOKENS words = [
} w.strip(".,:;\"'!?()")
if not cand_words: for w in text.lower().replace(",", " ").split()
]
return " ".join(w for w in words if w and w not in cls._NORMALIZE_STRIP_TOKENS)
def _invalid_subtasks(self, spans: list[dict[str, Any]]) -> list[str]:
"""Return the unique off-vocab subtask strings the VLM produced."""
seen: list[str] = []
for span in spans:
text = str((span or {}).get("text") or "").strip()
if not text:
continue continue
inter = len(words & cand_words) if self._canonicalize_subtask(text):
union = len(words | cand_words) continue
score = inter / union if union else 0.0 if text not in seen:
if best is None or score > best[0]: seen.append(text)
best = (score, candidate) return seen
if best is None:
return "" def _build_subtask_retry_message(
if not force and best[0] < self._CANONICALIZE_JACCARD_FLOOR: self, original_messages: list[dict[str, Any]], invalid: list[str]
logger.info( ) -> list[dict[str, Any]]:
"subtask %r dropped — best canonical match %r scored %.2f " """Compose a one-shot correction prompt naming the off-vocab strings."""
"(< %.2f Jaccard floor)", assert self.vocabulary is not None
cleaned, canonical = "\n".join(f"- {s}" for s in self.vocabulary.subtasks)
best[1], invalid_list = "\n".join(f"- {s!r}" for s in invalid)
best[0], correction = (
self._CANONICALIZE_JACCARD_FLOOR, "Your previous response included subtask labels that are NOT in "
"the canonical vocabulary:\n"
f"{invalid_list}\n\n"
"Re-emit the same segmentation (same number of spans, same start/end "
"timestamps where they were valid) but replace every off-vocab "
"label with the EXACT canonical string for that phase, copied "
"verbatim from this list:\n"
f"{canonical}\n\n"
"Strict rules:\n"
"- Output strings must be byte-for-byte identical to entries above.\n"
"- No articles, no adverbs, no extra words.\n"
"- If a phase truly has no canonical match, omit that span entirely.\n"
"Return the same JSON shape as before."
) )
return "" # Append the correction as an additional user turn; the model
return best[1] # sees the original prompt + its prior output is implied by the
# conversation context (the VLM client is stateless, so we
# re-send the original content plus this correction).
retry_messages = [
{
"role": m.get("role", "user"),
"content": (
m.get("content")
if isinstance(m.get("content"), str)
else list(m.get("content") or [])
),
}
for m in original_messages
]
retry_messages.append({"role": "user", "content": correction})
return retry_messages
def _generate_plan( def _generate_plan(
self, self,
+108 -56
View File
@@ -182,59 +182,17 @@ def test_plan_module_inlines_vocab_into_subtask_prompt(
assert any("grasp blue cube" in t for t in captured) assert any("grasp blue cube" in t for t in captured)
def test_plan_module_canonicalizes_paraphrased_subtask( def test_plan_module_accepts_article_only_difference(
fixture_dataset_root: Path, tmp_path: Path fixture_dataset_root: Path, tmp_path: Path
) -> None: ) -> None:
"""Off-vocab paraphrase with high token overlap snaps to canonical form.""" """Articles like 'the'/'a'/'an' are stripped during validation."""
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
def responder(_messages): def responder(_messages):
return { return {
"subtasks": [ "subtasks": [
# paraphrase of "grasp blue cube" — overlapping tokens # Same canonical phrase modulo "the" — should be accepted.
{"text": "grasp the blue cube", "start": 0.0, "end": 0.4}, {"text": "grasp the blue cube", "start": 0.0, "end": 0.4},
# paraphrase of "place blue cube in box" — high overlap
{"text": "place the blue cube into the box", "start": 0.4, "end": 0.9},
]
}
vlm = StubVlmClient(responder=responder)
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
module = PlanSubtasksMemoryModule(
vlm=vlm,
config=PlanConfig(n_task_rephrasings=0),
vocabulary=vocab,
)
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("plan")
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
# Both paraphrases snapped onto canonical strings.
assert subtask_texts == ["grasp blue cube", "place blue cube in box"]
def test_plan_module_drops_off_vocab_subtask(
fixture_dataset_root: Path, tmp_path: Path
) -> None:
"""A subtask with low overlap to every canonical label is dropped.
Drop only kicks in when *at least one* other subtask survives — if
every span would be dropped the episode would come out empty, so
``_generate_subtasks`` falls back to snap-without-floor; that path
is exercised by ``test_plan_module_snaps_when_all_off_vocab``.
"""
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
def responder(_messages):
return {
"subtasks": [
# in-vocab — keeps the episode non-empty so the floor
# is allowed to drop the next span.
{"text": "grasp blue cube", "start": 0.0, "end": 0.4},
# off-vocab hallucination — no token overlap above the
# Jaccard floor; should be dropped.
{"text": "perform a fancy macarena dance", "start": 0.4, "end": 0.9},
] ]
} }
@@ -253,18 +211,114 @@ def test_plan_module_drops_off_vocab_subtask(
assert subtask_texts == ["grasp blue cube"] assert subtask_texts == ["grasp blue cube"]
def test_plan_module_snaps_when_all_off_vocab( def test_plan_module_retries_when_subtask_off_vocab(
fixture_dataset_root: Path, tmp_path: Path fixture_dataset_root: Path, tmp_path: Path
) -> None: ) -> None:
"""All-off-vocab spans snap to nearest canonical instead of emptying the episode.""" """One-shot retry replaces an off-vocab paraphrase with the canonical form."""
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
call_count = {"n": 0}
def responder(messages):
call_count["n"] += 1
# First call: returns an off-vocab paraphrase.
if call_count["n"] == 1:
return {
"subtasks": [
# paraphrase, not in vocab
{"text": "pick up blue cube", "start": 0.0, "end": 0.4},
]
}
# Second call (the retry): should contain the correction prompt;
# respond with the canonical phrase exactly.
last_user_text = ""
for message in messages:
content = message.get("content")
if isinstance(content, str):
last_user_text = content
elif isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
last_user_text = block.get("text", "")
assert "NOT in the canonical vocabulary" in last_user_text
return {
"subtasks": [
{"text": "grasp blue cube", "start": 0.0, "end": 0.4},
]
}
vlm = StubVlmClient(responder=responder)
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
module = PlanSubtasksMemoryModule(
vlm=vlm,
config=PlanConfig(n_task_rephrasings=0),
vocabulary=vocab,
)
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("plan")
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
assert subtask_texts == ["grasp blue cube"]
# The retry must have fired exactly once.
assert call_count["n"] == 2
def test_plan_module_drops_off_vocab_subtask_after_retry(
fixture_dataset_root: Path, tmp_path: Path
) -> None:
"""If the VLM stays off-vocab even after the retry, the bad span is dropped."""
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
call_count = {"n": 0}
def responder(_messages):
call_count["n"] += 1
# Both calls return the same off-vocab span — the model can't
# be corrected. The second call also returns one in-vocab span
# so the episode isn't empty; this lets us check that the
# off-vocab span is dropped without affecting the in-vocab one.
if call_count["n"] == 1:
return {
"subtasks": [
{"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4},
{"text": "grasp blue cube", "start": 0.4, "end": 0.9},
]
}
return {
"subtasks": [
{"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4},
{"text": "grasp blue cube", "start": 0.4, "end": 0.9},
]
}
vlm = StubVlmClient(responder=responder)
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
module = PlanSubtasksMemoryModule(
vlm=vlm,
config=PlanConfig(n_task_rephrasings=0),
vocabulary=vocab,
)
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("plan")
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
# Retry fired exactly once; bad span dropped, good span kept.
assert call_count["n"] == 2
assert subtask_texts == ["grasp blue cube"]
def test_plan_module_empty_when_all_off_vocab_after_retry(
fixture_dataset_root: Path, tmp_path: Path
) -> None:
"""All-off-vocab spans → episode comes out empty (no silent fuzzy snap)."""
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
def responder(_messages): def responder(_messages):
# Returns the same off-vocab spans on both attempts.
return { return {
"subtasks": [ "subtasks": [
# Both off-vocab — would normally be dropped. The
# fallback should snap each to its best canonical match
# rather than leave the episode with no subtasks at all.
{"text": "make a smoothie", "start": 0.0, "end": 0.4}, {"text": "make a smoothie", "start": 0.0, "end": 0.4},
{"text": "consult the wizard", "start": 0.4, "end": 0.9}, {"text": "consult the wizard", "start": 0.4, "end": 0.9},
] ]
@@ -282,12 +336,10 @@ def test_plan_module_snaps_when_all_off_vocab(
module.run_episode(record, staging) module.run_episode(record, staging)
rows = staging.read("plan") rows = staging.read("plan")
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
# Two off-vocab spans → two canonical subtasks (snapped to nearest # No subtask gets fabricated — better to leave the episode empty
# by Jaccard with no floor). The exact canonical choice doesn't # so the operator notices the vocabulary gap than to silently
# matter — only that the episode came out with subtasks rather # warp the labels.
# than empty. assert subtask_texts == []
assert len(subtask_texts) == 2
assert all(s in _CANONICAL_SUBTASKS for s in subtask_texts)
def test_plan_module_without_vocab_passes_through( def test_plan_module_without_vocab_passes_through(