From f60babc946aafe3aa215c57ab2d636e5945da1ae Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 14 May 2026 13:39:42 +0200 Subject: [PATCH] refactor(inference): subtask = head of plan, drop AR generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaced the LM-head subtask generator with a deterministic plan-walker. The plan is already a numbered list of still-todo subtasks (produced by ``plan_generation``); the current subtask is just the head of that list, popped one position per chunk-dispatch boundary. Why --- The previous design called ``select_message`` every ~1 Hz to auto-regressively generate the current subtask string. On a small training set with a chat-pretrained backbone, the LM head collapses into n-gram loops under greedy decoding — "the robot arm extends and retracts and retracts and retracts from the beige surface and retracts from the surface" — which then poisons the action expert's language conditioning, because the same garbage string is what gets rendered into its prefix. Repetition penalty / no_repeat_ngram_size help but are bandaids. Walking the plan deterministically eliminates the failure mode at the root: the action expert always sees a real, recipe-aligned subtask phrase. Inference is also interpretable now — every robot motion traces back to a specific plan line — and we save the per- chunk ``select_message`` round trip (~2 s on MPS). The LM head is still trained (``high_level_subtask`` text-CE) and still used at runtime for one-shot plan generation, interjection replanning, and VQA. We just don't sit on the AR subtask loop. Advance trigger: ``actions_dispatched`` increment — each new chunk the action expert produced and dispatched advances ``plan_index`` by one. Clamps to the last plan line so the robot keeps executing the final subtask if it runs off the end instead of going silent. Falls back to the task string when no plan is set yet (e.g. before ``plan_generation`` has fired this episode). Added ``_parse_plan_lines`` helper for stripping the ``"N. "`` prefix off plan items. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/inference/steps.py | 179 ++++++++---------- 1 file changed, 80 insertions(+), 99 deletions(-) diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index ab86afdc5..92449895c 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -346,118 +346,77 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]: @dataclass class HighLevelSubtaskFwd(InferenceStep): - """At ~1 Hz, ask the policy for the next subtask. + """Deterministic plan walker — current subtask is the head of the plan. - Mirrors the ``high_level_subtask`` recipe layout exactly: + Rationale + --------- + The training-time ``plan_generation`` recipe summarises subtasks as + a numbered list of still-todo items. The ``high_level_subtask`` + recipe supervises subtask text auto-regressively. Empirically the + LM head's AR generation collapses on small datasets and produces + repetitive / off-distribution subtasks + ("…extends and retracts and retracts…"), which corrupts the + action expert's conditioning. - user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}" - user: "Current subtask: ${subtask}" (if subtask present) - ↓ generate ↓ - assistant: + Instead of trusting the LM head's free generation, walk the plan: + + * The plan is a string ``"1. \\n2. \\n…"``. + * The current subtask is the line at ``state["plan_index"]``. + * After each chunk dispatch (action queue empties), advance the + index by one. + + This is interpretable, can't repetition-collapse, and stays in + lockstep with the deterministic plan that ``plan_generation`` + emits. + + The LM head is still trained and still used for plan generation + (one-shot at episode start), VQA, and interjection replanning — + just not for the always-on per-chunk subtask loop. """ policy: Any = None - observation_provider: Any = None - """Same shape as ``LowLevelForward.observation_provider``. When - set, the resulting observation is merged into ``select_message``'s - batch so text generation runs against real video + state.""" - + observation_provider: Any = None # kept for signature compatibility trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0)) def run(self, state: dict[str, Any]) -> dict[str, Any] | None: - if self.policy is None or not state.get("task"): + if not state.get("task"): return None - # Gate to chunk boundaries: only generate a fresh subtask when - # the action queue is empty (i.e. right before LowLevelForward - # refreshes the chunk). ``select_message`` takes ~2 s on MPS, - # and running it every loop iteration starves DispatchAction - # at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead - # of 30/sec and the robot barely moves. Tying it to the same - # "queue empty" condition as the chunk refresh produces a - # clean sense → think → act cycle. + # Only advance at chunk boundaries (action queue empty) — same + # gating as the original AR path, so subtask transitions stay + # aligned with the dispatch cycle. queue = state.get("action_queue") or [] if len(queue) > 0: return None - ctx = _msgs_for_subtask(state) - observation = _maybe_observation(self.observation_provider) - # Default: greedy argmax, no min_new_tokens, no special-token - # suppression — matches training. Operator can override via - # ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P`` - # on the CLI; useful for under-trained checkpoints whose LM - # head still favours EOS at position 0 (pre-trained chat - # backbone's short-turn prior hasn't been fully overridden - # by the fine-tuning supervision yet). - msg = _generate_with_policy( - self.policy, - ctx, - observation=observation, - state=state, - label="subtask gen", - min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0), - temperature=float(state.get("text_gen_temperature") or 0.0), - top_p=float(state.get("text_gen_top_p") or 1.0), - repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0), - no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0), - ) - # Diagnostics: surface what the model is *actually* producing - # at chunk boundaries, even when the output gets rejected or - # repeats. Memorisation collapse looks like "same accepted - # subtask N times in a row" or "gibberish_count rising while - # current_subtask is stuck". The state panel renders these. - state["last_subtask_raw"] = msg or "" - # Persistent empty completion is its own failure mode (model - # immediately EOS-es from the chat-template generation - # prompt) — surface it once every N occurrences so the - # operator can distinguish "generation failing silently" - # from "generating fine but filter rejecting". - if not msg: - empties = state.get("subtask_empty_count", 0) + 1 - state["subtask_empty_count"] = empties - if empties == 1 or empties % 5 == 0: - debug = getattr(self.policy, "_last_select_message_debug", "") or "" - if debug: - push_log( - state, - f" [info] subtask gen empty (×{empties}); {debug}", - ) - else: - push_log( - state, - f" [info] subtask gen returned empty (×{empties}) — " - "no tokens generated (head EOS-ing before any " - "non-special token).", - ) - if msg and _looks_like_gibberish(msg): - # Bump a counter so the operator can see the model is - # struggling without spamming the log every tick. A first - # rejection still logs once so the failure is visible. - count = state.get("subtask_gibberish_count", 0) + 1 - state["subtask_gibberish_count"] = count - if count == 1 or count % 30 == 0: - push_log( - state, - f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}", - ) + + plan_lines = _parse_plan_lines(state.get("current_plan") or "") + if not plan_lines: + # No plan available yet (e.g. plan_generation hasn't fired + # on this episode). Fall back to the task string so the + # action expert at least has *something* to condition on. + fallback = state.get("task") or "" + if fallback and state.get("current_subtask") != fallback: + set_if_changed(state, "current_subtask", fallback, label="subtask") return None - if msg: - changed = set_if_changed(state, "current_subtask", msg, label="subtask") - if changed: - # Subtask change is a downstream trigger. - state.setdefault("events_this_tick", []).append("subtask_change") - state["subtask_repeat_count"] = 0 - else: - # Same accepted string regenerated — memorisation tell. - # Once this counter climbs past a few, you're seeing - # the model unable to move past the current subtask - # despite the chunk having drained (visual scene may - # have changed but the LM is replaying training - # tokens). - state["subtask_repeat_count"] = ( - state.get("subtask_repeat_count", 0) + 1 - ) - # Silently skip empty completions — common when the model - # warms up or generates only EOS; logging it every tick at - # ctrl_hz is just noise. + + # ``plan_index`` advances by 1 each chunk boundary. Tracked via + # ``actions_dispatched`` (set by ``DispatchAction``): we advance + # only when at least one new action was dispatched since the + # previous walker tick. Clamps to the last line so the robot + # keeps executing the final subtask instead of going silent. + plan_index = int(state.get("plan_index", 0) or 0) + dispatched = int(state.get("actions_dispatched", 0) or 0) + last_dispatched = state.get("_subtask_walker_last_dispatched") + if last_dispatched is not None and dispatched > int(last_dispatched): + plan_index = min(plan_index + 1, len(plan_lines) - 1) + state["plan_index"] = plan_index + state["_subtask_walker_last_dispatched"] = dispatched + + new_subtask = plan_lines[plan_index] + changed = set_if_changed(state, "current_subtask", new_subtask, label="subtask") + if changed: + state.setdefault("events_this_tick", []).append("subtask_change") + # Surface diagnostics in the panel format the runtime expects. + state["last_subtask_raw"] = new_subtask return None @@ -663,6 +622,28 @@ class DispatchToolCalls(InferenceStep): # --------------------------------------------------------------------------- +_PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$") + + +def _parse_plan_lines(plan_text: str) -> list[str]: + """Split a numbered-list plan string into its individual subtask lines. + + Accepts both ``"1. foo"`` and ``"1) foo"`` formatting. Falls back to + splitting on newlines for non-numbered plans. Empty lines are + dropped. Returns the bare subtask text (number prefix stripped). + """ + if not plan_text: + return [] + out: list[str] = [] + for raw in plan_text.splitlines(): + line = raw.strip() + if not line: + continue + m = _PLAN_LINE_RE.match(line) + out.append(m.group(1).strip() if m else line) + return out + + def _looks_like_gibberish(text: str) -> bool: """Heuristically detect generation that's clearly off the rails.