diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py index 218d98ad1..ed226bd02 100644 --- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -328,7 +328,7 @@ class PlanSubtasksMemoryModule: # clamp to [t0, t_last] and sort t0 = record.frame_timestamps[0] t_last = record.frame_timestamps[-1] - cleaned: list[dict[str, Any]] = [] + raw: list[dict[str, Any]] = [] for span in spans: try: start = float(span["start"]) @@ -340,12 +340,45 @@ class PlanSubtasksMemoryModule: end = max(t0, min(end, t_last)) if end < start: start, end = end, start - if not text: - continue - text = self._canonicalize_subtask(text) - if not text: - continue - cleaned.append({"text": text, "start": start, "end": end}) + if text: + raw.append({"text": text, "start": start, "end": end}) + + # Without a vocabulary, free-form spans pass through unchanged. + if self.vocabulary is None or not self.vocabulary.subtasks: + raw.sort(key=lambda s: s["start"]) + return raw + + # With a vocabulary, snap each span to the closest canonical + # label. Two-pass: first try the normal Jaccard floor (drops + # off-topic hallucinations); if that leaves the episode with + # zero subtasks, fall back to snap-without-floor so the episode + # is never silently emptied — a wrong canonical label is still + # closer to the right phase than nothing at all. 
+ cleaned: list[dict[str, Any]] = [] + for span in raw: + mapped = self._canonicalize_subtask(span["text"]) + if mapped: + cleaned.append({**span, "text": mapped}) + if not cleaned and raw: + logger.warning( + "episode %d: every VLM subtask was off-vocabulary " + "(%d spans); snapping to closest canonical label anyway " + "(check meta/canonical_vocabulary.json for missing phases)", + record.episode_index, + len(raw), + ) + for span in raw: + mapped = self._canonicalize_subtask(span["text"], force=True) + if mapped: + cleaned.append({**span, "text": mapped}) + elif len(cleaned) < len(raw): + logger.info( + "episode %d: %d/%d subtasks survived canonicalisation; " + "the rest were off-vocabulary", + record.episode_index, + len(cleaned), + len(raw), + ) cleaned.sort(key=lambda s: s["start"]) return cleaned @@ -387,15 +420,28 @@ class PlanSubtasksMemoryModule: f"{bullets}\n\n" ) - def _canonicalize_subtask(self, text: str) -> str: - """Snap ``text`` to the closest canonical subtask string, or drop it. + _CANONICALIZE_JACCARD_FLOOR: float = 0.25 + _CANONICALIZE_IGNORE_TOKENS: frozenset[str] = frozenset( + {"the", "a", "an", "to", "into", "from", "of", "on", "over", "at"} + ) + + def _canonicalize_subtask(self, text: str, *, force: bool = False) -> str: + """Snap ``text`` to the closest canonical subtask string. Without a vocabulary, the original text passes through. With a vocabulary, an exact case-insensitive match wins; failing that, the best Jaccard overlap on the word set is used as a tolerant - fuzzy match (handles articles / minor reorderings). If nothing - clears the floor, the subtask is dropped — better to skip a - phase than to feed the action expert an off-distribution string. + fuzzy match (handles articles / minor reorderings). + + Behaviour at the Jaccard floor depends on ``force``: + - ``force=False`` (default): below ``_CANONICALIZE_JACCARD_FLOOR`` + the subtask is dropped. 
``_generate_subtasks`` runs this first + to filter genuine off-topic hallucinations. + - ``force=True``: no floor — the best-scoring label wins (first candidate on ties); + ``""`` only when ``text`` (or every candidate) has no content tokens. ``_generate_subtasks`` + uses this in a second pass when the first pass would otherwise empty the episode — a + slightly-wrong canonical label is still closer to the right phase than no subtask at + all, which makes the whole episode invisible to the downstream policy. """ if self.vocabulary is None or not self.vocabulary.subtasks: return text.strip() @@ -406,14 +452,17 @@ class PlanSubtasksMemoryModule: if candidate.lower() == lowered: return candidate # Jaccard fallback: token-set overlap, drop articles + adverbs. - ignore = {"the", "a", "an", "to", "into", "from", "of", "on", "over", "at"} - words = {w for w in lowered.replace(",", " ").split() if w and w not in ignore} + words = { + w for w in lowered.replace(",", " ").split() + if w and w not in self._CANONICALIZE_IGNORE_TOKENS + } if not words: return "" best: tuple[float, str] | None = None for candidate in candidates: cand_words = { - w for w in candidate.lower().replace(",", " ").split() if w and w not in ignore + w for w in candidate.lower().replace(",", " ").split() + if w and w not in self._CANONICALIZE_IGNORE_TOKENS } if not cand_words: continue @@ -422,14 +471,16 @@ class PlanSubtasksMemoryModule: score = inter / union if union else 0.0 if best is None or score > best[0]: best = (score, candidate) - # Floor: require at least ~half the tokens to overlap. Below that - # the VLM is hallucinating a novel phrase; drop rather than warp - # it into something semantically wrong.
- if best is None or best[0] < 0.5: - logger.warning( - "subtask %r did not match any canonical label (best=%s) — dropping", + if best is None: + return "" + if not force and best[0] < self._CANONICALIZE_JACCARD_FLOOR: + logger.info( + "subtask %r dropped — best canonical match %r scored %.2f " + "(< %.2f Jaccard floor)", cleaned, - best, + best[1], + best[0], + self._CANONICALIZE_JACCARD_FLOOR, ) return "" return best[1] diff --git a/tests/annotations/test_vocabulary.py b/tests/annotations/test_vocabulary.py index a9f080a16..20d22a50d 100644 --- a/tests/annotations/test_vocabulary.py +++ b/tests/annotations/test_vocabulary.py @@ -217,13 +217,20 @@ def test_plan_module_canonicalizes_paraphrased_subtask( def test_plan_module_drops_off_vocab_subtask( fixture_dataset_root: Path, tmp_path: Path ) -> None: - """A subtask with low overlap to every canonical label is dropped.""" + """A subtask with low overlap to every canonical label is dropped. + + Drop only kicks in when *at least one* other subtask survives — if + every span would be dropped the episode would come out empty, so + ``_generate_subtasks`` falls back to snap-without-floor; that path + is exercised by ``test_plan_module_snaps_when_all_off_vocab``. + """ from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient def responder(_messages): return { "subtasks": [ - # in-vocab + # in-vocab — keeps the episode non-empty so the floor + # is allowed to drop the next span. {"text": "grasp blue cube", "start": 0.0, "end": 0.4}, # off-vocab hallucination — no token overlap above the # Jaccard floor; should be dropped. 
@@ -246,6 +253,43 @@ def test_plan_module_drops_off_vocab_subtask( assert subtask_texts == ["grasp blue cube"] +def test_plan_module_snaps_when_all_off_vocab( + fixture_dataset_root: Path, tmp_path: Path +) -> None: + """All-off-vocab spans snap to nearest canonical instead of emptying the episode.""" + from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient + + def responder(_messages): + return { + "subtasks": [ + # Both off-vocab — would normally be dropped. The + # fallback should snap each to its best canonical match + # rather than leave the episode with no subtasks at all. + {"text": "make a smoothie", "start": 0.0, "end": 0.4}, + {"text": "consult the wizard", "start": 0.4, "end": 0.9}, + ] + } + + vlm = StubVlmClient(responder=responder) + vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY) + module = PlanSubtasksMemoryModule( + vlm=vlm, + config=PlanConfig(n_task_rephrasings=0), + vocabulary=vocab, + ) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("plan") + subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"] + # Two off-vocab spans → two canonical subtasks (snapped to nearest + # by Jaccard with no floor). The exact canonical choice doesn't + # matter — only that the episode came out with subtasks rather + # than empty. + assert len(subtask_texts) == 2 + assert all(s in _CANONICAL_SUBTASKS for s in subtask_texts) + + def test_plan_module_without_vocab_passes_through( fixture_dataset_root: Path, tmp_path: Path ) -> None: