mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
refactor(inference): subtask = head of plan, drop AR generation
Replaced the LM-head subtask generator with a deterministic plan-walker. The plan is already a numbered list of still-todo subtasks (produced by ``plan_generation``); the current subtask is just the head of that list, popped one position per chunk-dispatch boundary. Why --- The previous design called ``select_message`` every ~1 Hz to auto-regressively generate the current subtask string. On a small training set with a chat-pretrained backbone, the LM head collapses into n-gram loops under greedy decoding — "the robot arm extends and retracts and retracts and retracts from the beige surface and retracts from the surface" — which then poisons the action expert's language conditioning, because the same garbage string is what gets rendered into its prefix. Repetition penalty / no_repeat_ngram_size help but are bandaids. Walking the plan deterministically eliminates the failure mode at the root: the action expert always sees a real, recipe-aligned subtask phrase. Inference is also interpretable now — every robot motion traces back to a specific plan line — and we save the per- chunk ``select_message`` round trip (~2 s on MPS). The LM head is still trained (``high_level_subtask`` text-CE) and still used at runtime for one-shot plan generation, interjection replanning, and VQA. We just don't sit on the AR subtask loop. Advance trigger: ``actions_dispatched`` increment — each new chunk the action expert produced and dispatched advances ``plan_index`` by one. Clamps to the last plan line so the robot keeps executing the final subtask if it runs off the end instead of going silent. Falls back to the task string when no plan is set yet (e.g. before ``plan_generation`` has fired this episode). Added ``_parse_plan_lines`` helper for stripping the ``"N. "`` prefix off plan items. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -346,118 +346,77 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HighLevelSubtaskFwd(InferenceStep):
|
class HighLevelSubtaskFwd(InferenceStep):
|
||||||
"""At ~1 Hz, ask the policy for the next subtask.
|
"""Deterministic plan walker — current subtask is the head of the plan.
|
||||||
|
|
||||||
Mirrors the ``high_level_subtask`` recipe layout exactly:
|
Rationale
|
||||||
|
---------
|
||||||
|
The training-time ``plan_generation`` recipe summarises subtasks as
|
||||||
|
a numbered list of still-todo items. The ``high_level_subtask``
|
||||||
|
recipe supervises subtask text auto-regressively. Empirically the
|
||||||
|
LM head's AR generation collapses on small datasets and produces
|
||||||
|
repetitive / off-distribution subtasks
|
||||||
|
("…extends and retracts and retracts…"), which corrupts the
|
||||||
|
action expert's conditioning.
|
||||||
|
|
||||||
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
Instead of trusting the LM head's free generation, walk the plan:
|
||||||
user: "Current subtask: ${subtask}" (if subtask present)
|
|
||||||
↓ generate ↓
|
* The plan is a string ``"1. <subtask>\\n2. <subtask>\\n…"``.
|
||||||
assistant: <next subtask>
|
* The current subtask is the line at ``state["plan_index"]``.
|
||||||
|
* After each chunk dispatch (action queue empties), advance the
|
||||||
|
index by one.
|
||||||
|
|
||||||
|
This is interpretable, can't repetition-collapse, and stays in
|
||||||
|
lockstep with the deterministic plan that ``plan_generation``
|
||||||
|
emits.
|
||||||
|
|
||||||
|
The LM head is still trained and still used for plan generation
|
||||||
|
(one-shot at episode start), VQA, and interjection replanning —
|
||||||
|
just not for the always-on per-chunk subtask loop.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
observation_provider: Any = None
|
observation_provider: Any = None # kept for signature compatibility
|
||||||
"""Same shape as ``LowLevelForward.observation_provider``. When
|
|
||||||
set, the resulting observation is merged into ``select_message``'s
|
|
||||||
batch so text generation runs against real video + state."""
|
|
||||||
|
|
||||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||||
|
|
||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
if self.policy is None or not state.get("task"):
|
if not state.get("task"):
|
||||||
return None
|
return None
|
||||||
# Gate to chunk boundaries: only generate a fresh subtask when
|
# Only advance at chunk boundaries (action queue empty) — same
|
||||||
# the action queue is empty (i.e. right before LowLevelForward
|
# gating as the original AR path, so subtask transitions stay
|
||||||
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
|
# aligned with the dispatch cycle.
|
||||||
# and running it every loop iteration starves DispatchAction
|
|
||||||
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
|
|
||||||
# of 30/sec and the robot barely moves. Tying it to the same
|
|
||||||
# "queue empty" condition as the chunk refresh produces a
|
|
||||||
# clean sense → think → act cycle.
|
|
||||||
queue = state.get("action_queue") or []
|
queue = state.get("action_queue") or []
|
||||||
if len(queue) > 0:
|
if len(queue) > 0:
|
||||||
return None
|
return None
|
||||||
ctx = _msgs_for_subtask(state)
|
|
||||||
observation = _maybe_observation(self.observation_provider)
|
plan_lines = _parse_plan_lines(state.get("current_plan") or "")
|
||||||
# Default: greedy argmax, no min_new_tokens, no special-token
|
if not plan_lines:
|
||||||
# suppression — matches training. Operator can override via
|
# No plan available yet (e.g. plan_generation hasn't fired
|
||||||
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
|
# on this episode). Fall back to the task string so the
|
||||||
# on the CLI; useful for under-trained checkpoints whose LM
|
# action expert at least has *something* to condition on.
|
||||||
# head still favours EOS at position 0 (pre-trained chat
|
fallback = state.get("task") or ""
|
||||||
# backbone's short-turn prior hasn't been fully overridden
|
if fallback and state.get("current_subtask") != fallback:
|
||||||
# by the fine-tuning supervision yet).
|
set_if_changed(state, "current_subtask", fallback, label="subtask")
|
||||||
msg = _generate_with_policy(
|
|
||||||
self.policy,
|
|
||||||
ctx,
|
|
||||||
observation=observation,
|
|
||||||
state=state,
|
|
||||||
label="subtask gen",
|
|
||||||
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
|
|
||||||
temperature=float(state.get("text_gen_temperature") or 0.0),
|
|
||||||
top_p=float(state.get("text_gen_top_p") or 1.0),
|
|
||||||
repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
|
|
||||||
no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
|
|
||||||
)
|
|
||||||
# Diagnostics: surface what the model is *actually* producing
|
|
||||||
# at chunk boundaries, even when the output gets rejected or
|
|
||||||
# repeats. Memorisation collapse looks like "same accepted
|
|
||||||
# subtask N times in a row" or "gibberish_count rising while
|
|
||||||
# current_subtask is stuck". The state panel renders these.
|
|
||||||
state["last_subtask_raw"] = msg or ""
|
|
||||||
# Persistent empty completion is its own failure mode (model
|
|
||||||
# immediately EOS-es from the chat-template generation
|
|
||||||
# prompt) — surface it once every N occurrences so the
|
|
||||||
# operator can distinguish "generation failing silently"
|
|
||||||
# from "generating fine but filter rejecting".
|
|
||||||
if not msg:
|
|
||||||
empties = state.get("subtask_empty_count", 0) + 1
|
|
||||||
state["subtask_empty_count"] = empties
|
|
||||||
if empties == 1 or empties % 5 == 0:
|
|
||||||
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
|
|
||||||
if debug:
|
|
||||||
push_log(
|
|
||||||
state,
|
|
||||||
f" [info] subtask gen empty (×{empties}); {debug}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
push_log(
|
|
||||||
state,
|
|
||||||
f" [info] subtask gen returned empty (×{empties}) — "
|
|
||||||
"no tokens generated (head EOS-ing before any "
|
|
||||||
"non-special token).",
|
|
||||||
)
|
|
||||||
if msg and _looks_like_gibberish(msg):
|
|
||||||
# Bump a counter so the operator can see the model is
|
|
||||||
# struggling without spamming the log every tick. A first
|
|
||||||
# rejection still logs once so the failure is visible.
|
|
||||||
count = state.get("subtask_gibberish_count", 0) + 1
|
|
||||||
state["subtask_gibberish_count"] = count
|
|
||||||
if count == 1 or count % 30 == 0:
|
|
||||||
push_log(
|
|
||||||
state,
|
|
||||||
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
if msg:
|
|
||||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
# ``plan_index`` advances by 1 each chunk boundary. Tracked via
|
||||||
if changed:
|
# ``actions_dispatched`` (set by ``DispatchAction``): we advance
|
||||||
# Subtask change is a downstream trigger.
|
# only when at least one new action was dispatched since the
|
||||||
state.setdefault("events_this_tick", []).append("subtask_change")
|
# previous walker tick. Clamps to the last line so the robot
|
||||||
state["subtask_repeat_count"] = 0
|
# keeps executing the final subtask instead of going silent.
|
||||||
else:
|
plan_index = int(state.get("plan_index", 0) or 0)
|
||||||
# Same accepted string regenerated — memorisation tell.
|
dispatched = int(state.get("actions_dispatched", 0) or 0)
|
||||||
# Once this counter climbs past a few, you're seeing
|
last_dispatched = state.get("_subtask_walker_last_dispatched")
|
||||||
# the model unable to move past the current subtask
|
if last_dispatched is not None and dispatched > int(last_dispatched):
|
||||||
# despite the chunk having drained (visual scene may
|
plan_index = min(plan_index + 1, len(plan_lines) - 1)
|
||||||
# have changed but the LM is replaying training
|
state["plan_index"] = plan_index
|
||||||
# tokens).
|
state["_subtask_walker_last_dispatched"] = dispatched
|
||||||
state["subtask_repeat_count"] = (
|
|
||||||
state.get("subtask_repeat_count", 0) + 1
|
new_subtask = plan_lines[plan_index]
|
||||||
)
|
changed = set_if_changed(state, "current_subtask", new_subtask, label="subtask")
|
||||||
# Silently skip empty completions — common when the model
|
if changed:
|
||||||
# warms up or generates only EOS; logging it every tick at
|
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||||
# ctrl_hz is just noise.
|
# Surface diagnostics in the panel format the runtime expects.
|
||||||
|
state["last_subtask_raw"] = new_subtask
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -663,6 +622,28 @@ class DispatchToolCalls(InferenceStep):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_plan_lines(plan_text: str) -> list[str]:
|
||||||
|
"""Split a numbered-list plan string into its individual subtask lines.
|
||||||
|
|
||||||
|
Accepts both ``"1. foo"`` and ``"1) foo"`` formatting. Falls back to
|
||||||
|
splitting on newlines for non-numbered plans. Empty lines are
|
||||||
|
dropped. Returns the bare subtask text (number prefix stripped).
|
||||||
|
"""
|
||||||
|
if not plan_text:
|
||||||
|
return []
|
||||||
|
out: list[str] = []
|
||||||
|
for raw in plan_text.splitlines():
|
||||||
|
line = raw.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
m = _PLAN_LINE_RE.match(line)
|
||||||
|
out.append(m.group(1).strip() if m else line)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _looks_like_gibberish(text: str) -> bool:
|
def _looks_like_gibberish(text: str) -> bool:
|
||||||
"""Heuristically detect generation that's clearly off the rails.
|
"""Heuristically detect generation that's clearly off the rails.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user