refactor(inference): subtask = head of plan, drop AR generation

Replaced the LM-head subtask generator with a deterministic
plan-walker. The plan is already a numbered list of still-todo
subtasks (produced by ``plan_generation``); the current subtask is
just the head of that list, popped one position per chunk-dispatch
boundary.

Why
---
The previous design called ``select_message`` at ~1 Hz to
auto-regressively generate the current subtask string. On a small
training set with a chat-pretrained backbone, the LM head collapses
into n-gram loops under greedy decoding —

  "the robot arm extends and retracts and retracts and retracts
   from the beige surface and retracts from the surface"

— which then poisons the action expert's language conditioning,
because that same garbage string is what gets rendered into its
prefix. Repetition penalty / ``no_repeat_ngram_size`` help, but
they are band-aids.

Walking the plan deterministically eliminates the failure mode at
the root: the action expert always sees a real, recipe-aligned
subtask phrase. Inference is also interpretable now — every robot
motion traces back to a specific plan line — and we save the
per-chunk ``select_message`` round trip (~2 s on MPS).

The LM head is still trained (``high_level_subtask`` text-CE) and
still used at runtime for one-shot plan generation, interjection
replanning, and VQA. We just don't sit on the AR subtask loop.

Advance trigger: an ``actions_dispatched`` increment — each new chunk
the action expert produces and dispatches advances ``plan_index``
by one. The index clamps to the last plan line, so the robot keeps
executing the final subtask if it runs off the end instead of going
silent, and the subtask falls back to the task string when no plan is
set yet (e.g. before ``plan_generation`` has fired this episode).
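
The advance logic can be sketched standalone (simplified; the real
implementation lives in ``HighLevelSubtaskFwd.run`` below, and the key
names ``plan_lines`` and ``_last_dispatched`` here are illustrative,
not the runtime's actual state keys):

```python
def walk_plan(state: dict) -> str:
    """Return the current subtask; advance when new actions were dispatched.

    Sketch only — ``plan_lines`` / ``_last_dispatched`` are illustrative
    stand-ins for the runtime's actual state keys.
    """
    plan_lines = state.get("plan_lines") or []
    if not plan_lines:
        # No plan yet: fall back to the raw task string.
        return state.get("task", "")
    idx = int(state.get("plan_index", 0))
    dispatched = int(state.get("actions_dispatched", 0))
    last = state.get("_last_dispatched")
    if last is not None and dispatched > last:
        # A chunk was dispatched since the previous tick: pop to the next
        # line, clamping so the final subtask keeps executing at plan end.
        idx = min(idx + 1, len(plan_lines) - 1)
    state["plan_index"] = idx
    state["_last_dispatched"] = dispatched
    return plan_lines[idx]
```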

Added ``_parse_plan_lines`` helper for stripping the ``"N. "``
prefix off plan items.
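
Its behaviour, sketched standalone (the regex mirrors the one in the
diff; ``parse_plan_lines`` is an illustrative public-named twin of the
private helper):

```python
import re

# Matches "1. foo" and "1) foo"; group(1) is the bare subtask text.
_PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$")

def parse_plan_lines(plan_text: str) -> list[str]:
    """Split a numbered-list plan into bare subtask lines.

    Non-numbered lines pass through as-is; empty lines are dropped.
    """
    out: list[str] = []
    for raw in plan_text.splitlines():
        line = raw.strip()
        if not line:
            continue
        m = _PLAN_LINE_RE.match(line)
        out.append(m.group(1).strip() if m else line)
    return out
```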

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author: Pepijn
Date:   2026-05-14 13:39:42 +02:00
Parent: db2972fb6c
Commit: f60babc946
@@ -346,118 +346,77 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
 @dataclass
 class HighLevelSubtaskFwd(InferenceStep):
-    """At ~1 Hz, ask the policy for the next subtask.
+    """Deterministic plan walker — current subtask is the head of the plan.
-    Mirrors the ``high_level_subtask`` recipe layout exactly:
+
+    Rationale
+    ---------
+    The training-time ``plan_generation`` recipe summarises subtasks as
+    a numbered list of still-todo items. The ``high_level_subtask``
+    recipe supervises subtask text auto-regressively. Empirically the
+    LM head's AR generation collapses on small datasets and produces
+    repetitive / off-distribution subtasks
+    ("…extends and retracts and retracts…"), which corrupts the
+    action expert's conditioning.
+
-    user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
-    user: "Current subtask: ${subtask}" (if subtask present)
-        ↓ generate ↓
-    assistant: <next subtask>
+    Instead of trusting the LM head's free generation, walk the plan:
+
+    * The plan is a string ``"1. <subtask>\\n2. <subtask>\\n…"``.
+    * The current subtask is the line at ``state["plan_index"]``.
+    * After each chunk dispatch (action queue empties), advance the
+      index by one.
+
+    This is interpretable, can't repetition-collapse, and stays in
+    lockstep with the deterministic plan that ``plan_generation``
+    emits.
+
+    The LM head is still trained and still used for plan generation
+    (one-shot at episode start), VQA, and interjection replanning —
+    just not for the always-on per-chunk subtask loop.
     """
     policy: Any = None
-    observation_provider: Any = None
-    """Same shape as ``LowLevelForward.observation_provider``. When
-    set, the resulting observation is merged into ``select_message``'s
-    batch so text generation runs against real video + state."""
+    observation_provider: Any = None  # kept for signature compatibility
     trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
 
     def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
-        if self.policy is None or not state.get("task"):
+        if not state.get("task"):
             return None
-        # Gate to chunk boundaries: only generate a fresh subtask when
-        # the action queue is empty (i.e. right before LowLevelForward
-        # refreshes the chunk). ``select_message`` takes ~2 s on MPS,
-        # and running it every loop iteration starves DispatchAction
-        # at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
-        # of 30/sec and the robot barely moves. Tying it to the same
-        # "queue empty" condition as the chunk refresh produces a
-        # clean sense → think → act cycle.
+        # Only advance at chunk boundaries (action queue empty) — same
+        # gating as the original AR path, so subtask transitions stay
+        # aligned with the dispatch cycle.
         queue = state.get("action_queue") or []
         if len(queue) > 0:
             return None
-        ctx = _msgs_for_subtask(state)
-        observation = _maybe_observation(self.observation_provider)
-        # Default: greedy argmax, no min_new_tokens, no special-token
-        # suppression — matches training. Operator can override via
-        # ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
-        # on the CLI; useful for under-trained checkpoints whose LM
-        # head still favours EOS at position 0 (pre-trained chat
-        # backbone's short-turn prior hasn't been fully overridden
-        # by the fine-tuning supervision yet).
-        msg = _generate_with_policy(
-            self.policy,
-            ctx,
-            observation=observation,
-            state=state,
-            label="subtask gen",
-            min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
-            temperature=float(state.get("text_gen_temperature") or 0.0),
-            top_p=float(state.get("text_gen_top_p") or 1.0),
-            repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
-            no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
-        )
-        # Diagnostics: surface what the model is *actually* producing
-        # at chunk boundaries, even when the output gets rejected or
-        # repeats. Memorisation collapse looks like "same accepted
-        # subtask N times in a row" or "gibberish_count rising while
-        # current_subtask is stuck". The state panel renders these.
-        state["last_subtask_raw"] = msg or ""
-        # Persistent empty completion is its own failure mode (model
-        # immediately EOS-es from the chat-template generation
-        # prompt) — surface it once every N occurrences so the
-        # operator can distinguish "generation failing silently"
-        # from "generating fine but filter rejecting".
-        if not msg:
-            empties = state.get("subtask_empty_count", 0) + 1
-            state["subtask_empty_count"] = empties
-            if empties == 1 or empties % 5 == 0:
-                debug = getattr(self.policy, "_last_select_message_debug", "") or ""
-                if debug:
-                    push_log(
-                        state,
-                        f"  [info] subtask gen empty (×{empties}); {debug}",
-                    )
-                else:
-                    push_log(
-                        state,
-                        f"  [info] subtask gen returned empty (×{empties}) — "
-                        "no tokens generated (head EOS-ing before any "
-                        "non-special token).",
-                    )
-        if msg and _looks_like_gibberish(msg):
-            # Bump a counter so the operator can see the model is
-            # struggling without spamming the log every tick. A first
-            # rejection still logs once so the failure is visible.
-            count = state.get("subtask_gibberish_count", 0) + 1
-            state["subtask_gibberish_count"] = count
-            if count == 1 or count % 30 == 0:
-                push_log(
-                    state,
-                    f"  [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
-                )
+        plan_lines = _parse_plan_lines(state.get("current_plan") or "")
+        if not plan_lines:
+            # No plan available yet (e.g. plan_generation hasn't fired
+            # on this episode). Fall back to the task string so the
+            # action expert at least has *something* to condition on.
+            fallback = state.get("task") or ""
+            if fallback and state.get("current_subtask") != fallback:
+                set_if_changed(state, "current_subtask", fallback, label="subtask")
+            return None
-        if msg:
-            changed = set_if_changed(state, "current_subtask", msg, label="subtask")
-            if changed:
-                # Subtask change is a downstream trigger.
-                state.setdefault("events_this_tick", []).append("subtask_change")
-                state["subtask_repeat_count"] = 0
-            else:
-                # Same accepted string regenerated — memorisation tell.
-                # Once this counter climbs past a few, you're seeing
-                # the model unable to move past the current subtask
-                # despite the chunk having drained (visual scene may
-                # have changed but the LM is replaying training
-                # tokens).
-                state["subtask_repeat_count"] = (
-                    state.get("subtask_repeat_count", 0) + 1
-                )
-        # Silently skip empty completions — common when the model
-        # warms up or generates only EOS; logging it every tick at
-        # ctrl_hz is just noise.
+        # ``plan_index`` advances by 1 each chunk boundary. Tracked via
+        # ``actions_dispatched`` (set by ``DispatchAction``): we advance
+        # only when at least one new action was dispatched since the
+        # previous walker tick. Clamps to the last line so the robot
+        # keeps executing the final subtask instead of going silent.
+        plan_index = int(state.get("plan_index", 0) or 0)
+        dispatched = int(state.get("actions_dispatched", 0) or 0)
+        last_dispatched = state.get("_subtask_walker_last_dispatched")
+        if last_dispatched is not None and dispatched > int(last_dispatched):
+            plan_index = min(plan_index + 1, len(plan_lines) - 1)
+        state["plan_index"] = plan_index
+        state["_subtask_walker_last_dispatched"] = dispatched
+        new_subtask = plan_lines[plan_index]
+        changed = set_if_changed(state, "current_subtask", new_subtask, label="subtask")
+        if changed:
+            state.setdefault("events_this_tick", []).append("subtask_change")
+        # Surface diagnostics in the panel format the runtime expects.
+        state["last_subtask_raw"] = new_subtask
+        return None
@@ -663,6 +622,28 @@ class DispatchToolCalls(InferenceStep):
 # ---------------------------------------------------------------------------
+_PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$")
+
+
+def _parse_plan_lines(plan_text: str) -> list[str]:
+    """Split a numbered-list plan string into its individual subtask lines.
+
+    Accepts both ``"1. foo"`` and ``"1) foo"`` formatting. Falls back to
+    splitting on newlines for non-numbered plans. Empty lines are
+    dropped. Returns the bare subtask text (number prefix stripped).
+    """
+    if not plan_text:
+        return []
+    out: list[str] = []
+    for raw in plan_text.splitlines():
+        line = raw.strip()
+        if not line:
+            continue
+        m = _PLAN_LINE_RE.match(line)
+        out.append(m.group(1).strip() if m else line)
+    return out
+
+
 def _looks_like_gibberish(text: str) -> bool:
     """Heuristically detect generation that's clearly off the rails.