refactor(inference): subtask = head of plan, drop AR generation

Replaced the LM-head subtask generator with a deterministic
plan-walker. The plan is already a numbered list of still-todo
subtasks (produced by ``plan_generation``); the current subtask is
just the head of that list, popped one position per chunk-dispatch
boundary.

Why
---
The previous design called ``select_message`` at ~1 Hz to
auto-regressively generate the current subtask string. On a small
training set with a chat-pretrained backbone, the LM head collapses
into n-gram loops under greedy decoding, e.g.

  "the robot arm extends and retracts and retracts and retracts
   from the beige surface and retracts from the surface"

That string is rendered verbatim into the action expert's prefix, so
the garbage poisons its language conditioning. Repetition penalty /
no_repeat_ngram_size help but are band-aids.
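
To make the failure mode concrete, here is a minimal sketch of a
post-hoc check for repeated n-grams. It is an illustration only:
``has_repeated_ngram`` is a hypothetical helper, and it works on
whitespace-separated words, whereas a decoder-side
no_repeat_ngram_size constraint operates on token IDs during
generation.

```python
def has_repeated_ngram(text: str, n: int) -> bool:
    """Return True if any word-level n-gram occurs more than once in ``text``."""
    words = text.split()
    seen: set[tuple[str, ...]] = set()
    for i in range(len(words) - n + 1):
        gram = tuple(words[i : i + n])
        if gram in seen:
            return True
        seen.add(gram)
    return False


# The collapsed sample quoted in this commit message.
collapsed = (
    "the robot arm extends and retracts and retracts and retracts "
    "from the beige surface and retracts from the surface"
)
print(has_repeated_ngram(collapsed, 3))
print(has_repeated_ngram("pick up the cup and place it in the bin", 3))
```

A check like this only flags the loop after the fact; it does not
recover a usable subtask.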

Walking the plan deterministically eliminates the failure mode at
the root: the action expert always sees a real, recipe-aligned
subtask phrase. Inference is also interpretable now — every robot
motion traces back to a specific plan line — and we save the per-
chunk ``select_message`` round trip (~2 s on MPS).

The LM head is still trained (``high_level_subtask`` text-CE) and
still used at runtime for one-shot plan generation, interjection
replanning, and VQA. It just no longer drives the always-on AR
subtask loop.

Advance trigger: an increase in ``actions_dispatched``. Each new chunk
the action expert produces and dispatches advances ``plan_index``
by one. The index clamps to the last plan line, so the robot keeps
executing the final subtask instead of going silent if it runs off
the end.
Falls back to the task string when no plan is set yet (e.g. before
``plan_generation`` has fired this episode).
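
The advance, clamp, and fallback rules can be sketched in isolation
as below. This is a simplified illustration, not the code from the
diff: ``walk_plan`` and the ``_last_dispatched`` bookkeeping key are
hypothetical names, ``state`` is assumed to be a plain dict, and the
plan is assumed to be already split into lines.

```python
from typing import Any


def walk_plan(state: dict[str, Any], plan_lines: list[str]) -> str:
    """Return the subtask to condition on, advancing at most one line per call."""
    if not plan_lines:
        # No plan yet: fall back to the task string.
        return state.get("task") or ""

    index = int(state.get("plan_index", 0) or 0)
    dispatched = int(state.get("actions_dispatched", 0) or 0)
    last = state.get("_last_dispatched")  # hypothetical bookkeeping key

    # Advance only if new actions were dispatched since the previous call.
    if last is not None and dispatched > int(last):
        # Clamp so the final subtask keeps being returned after the plan ends.
        index = min(index + 1, len(plan_lines) - 1)
    state["plan_index"] = index
    state["_last_dispatched"] = dispatched
    return plan_lines[index]


state = {"task": "clear the table", "actions_dispatched": 0}
plan = ["pick up the cup", "place the cup in the bin"]
seen = []
for dispatched in (0, 1, 2, 3):
    state["actions_dispatched"] = dispatched
    seen.append(walk_plan(state, plan))
print(seen)
```

In this toy run the first call returns the first plan line, the
second call advances to the second line, and later calls stay on the
final line because of the clamp.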

Added ``_parse_plan_lines`` helper for stripping the ``"N. "``
prefix off plan items.
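
As a quick illustration of the prefix stripping, the sketch below
reuses the regular expression from the diff in a standalone function.
``parse_plan_lines`` is a renamed copy for demonstration, and the
sample plan text is made up.

```python
import re

# Same pattern as in the diff: optional leading whitespace, a number,
# "." or ")", then the subtask text.
PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$")


def parse_plan_lines(plan_text: str) -> list[str]:
    """Split a plan string into bare subtask lines, dropping empty lines."""
    out: list[str] = []
    for raw in plan_text.splitlines():
        line = raw.strip()
        if not line:
            continue
        m = PLAN_LINE_RE.match(line)
        # Lines without a number prefix are kept as they are.
        out.append(m.group(1) if m else line)
    return out


parsed = parse_plan_lines(
    "1. pick up the cup\n2) place the cup in the bin\n\nwipe the table"
)
print(parsed)
```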

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-05-14 13:39:42 +02:00
parent db2972fb6c
commit f60babc946
@@ -346,118 +346,77 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
 @dataclass
 class HighLevelSubtaskFwd(InferenceStep):
-    """At ~1 Hz, ask the policy for the next subtask.
+    """Deterministic plan walker — current subtask is the head of the plan.

-    Mirrors the ``high_level_subtask`` recipe layout exactly:
+    Rationale
+    ---------
+    The training-time ``plan_generation`` recipe summarises subtasks as
+    a numbered list of still-todo items. The ``high_level_subtask``
+    recipe supervises subtask text auto-regressively. Empirically the
+    LM head's AR generation collapses on small datasets and produces
+    repetitive / off-distribution subtasks
+    ("…extends and retracts and retracts…"), which corrupts the
+    action expert's conditioning.

-        user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
-        user: "Current subtask: ${subtask}" (if subtask present)
-        ↓ generate ↓
-        assistant: <next subtask>
+    Instead of trusting the LM head's free generation, walk the plan:
+
+    * The plan is a string ``"1. <subtask>\\n2. <subtask>\\n…"``.
+    * The current subtask is the line at ``state["plan_index"]``.
+    * After each chunk dispatch (action queue empties), advance the
+      index by one.
+
+    This is interpretable, can't repetition-collapse, and stays in
+    lockstep with the deterministic plan that ``plan_generation``
+    emits.
+
+    The LM head is still trained and still used for plan generation
+    (one-shot at episode start), VQA, and interjection replanning —
+    just not for the always-on per-chunk subtask loop.
     """

     policy: Any = None
-    observation_provider: Any = None
-    """Same shape as ``LowLevelForward.observation_provider``. When
-    set, the resulting observation is merged into ``select_message``'s
-    batch so text generation runs against real video + state."""
+    observation_provider: Any = None  # kept for signature compatibility
     trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))

     def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
-        if self.policy is None or not state.get("task"):
+        if not state.get("task"):
             return None
-        # Gate to chunk boundaries: only generate a fresh subtask when
-        # the action queue is empty (i.e. right before LowLevelForward
-        # refreshes the chunk). ``select_message`` takes ~2 s on MPS,
-        # and running it every loop iteration starves DispatchAction
-        # at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
-        # of 30/sec and the robot barely moves. Tying it to the same
-        # "queue empty" condition as the chunk refresh produces a
-        # clean sense → think → act cycle.
+        # Only advance at chunk boundaries (action queue empty) — same
+        # gating as the original AR path, so subtask transitions stay
+        # aligned with the dispatch cycle.
         queue = state.get("action_queue") or []
         if len(queue) > 0:
             return None
-        ctx = _msgs_for_subtask(state)
-        observation = _maybe_observation(self.observation_provider)
-        # Default: greedy argmax, no min_new_tokens, no special-token
-        # suppression — matches training. Operator can override via
-        # ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
-        # on the CLI; useful for under-trained checkpoints whose LM
-        # head still favours EOS at position 0 (pre-trained chat
-        # backbone's short-turn prior hasn't been fully overridden
-        # by the fine-tuning supervision yet).
-        msg = _generate_with_policy(
-            self.policy,
-            ctx,
-            observation=observation,
-            state=state,
-            label="subtask gen",
-            min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
-            temperature=float(state.get("text_gen_temperature") or 0.0),
-            top_p=float(state.get("text_gen_top_p") or 1.0),
-            repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
-            no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
-        )
-        # Diagnostics: surface what the model is *actually* producing
-        # at chunk boundaries, even when the output gets rejected or
-        # repeats. Memorisation collapse looks like "same accepted
-        # subtask N times in a row" or "gibberish_count rising while
-        # current_subtask is stuck". The state panel renders these.
-        state["last_subtask_raw"] = msg or ""
-        # Persistent empty completion is its own failure mode (model
-        # immediately EOS-es from the chat-template generation
-        # prompt) — surface it once every N occurrences so the
-        # operator can distinguish "generation failing silently"
-        # from "generating fine but filter rejecting".
-        if not msg:
-            empties = state.get("subtask_empty_count", 0) + 1
-            state["subtask_empty_count"] = empties
-            if empties == 1 or empties % 5 == 0:
-                debug = getattr(self.policy, "_last_select_message_debug", "") or ""
-                if debug:
-                    push_log(
-                        state,
-                        f" [info] subtask gen empty (×{empties}); {debug}",
-                    )
-                else:
-                    push_log(
-                        state,
-                        f" [info] subtask gen returned empty (×{empties}) — "
-                        "no tokens generated (head EOS-ing before any "
-                        "non-special token).",
-                    )
-        if msg and _looks_like_gibberish(msg):
-            # Bump a counter so the operator can see the model is
-            # struggling without spamming the log every tick. A first
-            # rejection still logs once so the failure is visible.
-            count = state.get("subtask_gibberish_count", 0) + 1
-            state["subtask_gibberish_count"] = count
-            if count == 1 or count % 30 == 0:
-                push_log(
-                    state,
-                    f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
-                )
+
+        plan_lines = _parse_plan_lines(state.get("current_plan") or "")
+        if not plan_lines:
+            # No plan available yet (e.g. plan_generation hasn't fired
+            # on this episode). Fall back to the task string so the
+            # action expert at least has *something* to condition on.
+            fallback = state.get("task") or ""
+            if fallback and state.get("current_subtask") != fallback:
+                set_if_changed(state, "current_subtask", fallback, label="subtask")
             return None
-        if msg:
-            changed = set_if_changed(state, "current_subtask", msg, label="subtask")
-            if changed:
-                # Subtask change is a downstream trigger.
-                state.setdefault("events_this_tick", []).append("subtask_change")
-                state["subtask_repeat_count"] = 0
-            else:
-                # Same accepted string regenerated — memorisation tell.
-                # Once this counter climbs past a few, you're seeing
-                # the model unable to move past the current subtask
-                # despite the chunk having drained (visual scene may
-                # have changed but the LM is replaying training
-                # tokens).
-                state["subtask_repeat_count"] = (
-                    state.get("subtask_repeat_count", 0) + 1
-                )
-        # Silently skip empty completions — common when the model
-        # warms up or generates only EOS; logging it every tick at
-        # ctrl_hz is just noise.
+
+        # ``plan_index`` advances by 1 each chunk boundary. Tracked via
+        # ``actions_dispatched`` (set by ``DispatchAction``): we advance
+        # only when at least one new action was dispatched since the
+        # previous walker tick. Clamps to the last line so the robot
+        # keeps executing the final subtask instead of going silent.
+        plan_index = int(state.get("plan_index", 0) or 0)
+        dispatched = int(state.get("actions_dispatched", 0) or 0)
+        last_dispatched = state.get("_subtask_walker_last_dispatched")
+        if last_dispatched is not None and dispatched > int(last_dispatched):
+            plan_index = min(plan_index + 1, len(plan_lines) - 1)
+            state["plan_index"] = plan_index
+        state["_subtask_walker_last_dispatched"] = dispatched
+
+        new_subtask = plan_lines[plan_index]
+        changed = set_if_changed(state, "current_subtask", new_subtask, label="subtask")
+        if changed:
+            state.setdefault("events_this_tick", []).append("subtask_change")
+        # Surface diagnostics in the panel format the runtime expects.
+        state["last_subtask_raw"] = new_subtask
         return None
@@ -663,6 +622,28 @@ class DispatchToolCalls(InferenceStep):
 # ---------------------------------------------------------------------------


+_PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$")
+
+
+def _parse_plan_lines(plan_text: str) -> list[str]:
+    """Split a numbered-list plan string into its individual subtask lines.
+
+    Accepts both ``"1. foo"`` and ``"1) foo"`` formatting. Falls back to
+    splitting on newlines for non-numbered plans. Empty lines are
+    dropped. Returns the bare subtask text (number prefix stripped).
+    """
+    if not plan_text:
+        return []
+    out: list[str] = []
+    for raw in plan_text.splitlines():
+        line = raw.strip()
+        if not line:
+            continue
+        m = _PLAN_LINE_RE.match(line)
+        out.append(m.group(1).strip() if m else line)
+    return out
+
+
 def _looks_like_gibberish(text: str) -> bool:
     """Heuristically detect generation that's clearly off the rails.