diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py
index ab86afdc5..92449895c 100644
--- a/src/lerobot/policies/smolvla2/inference/steps.py
+++ b/src/lerobot/policies/smolvla2/inference/steps.py
@@ -346,118 +346,77 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
 @dataclass
 class HighLevelSubtaskFwd(InferenceStep):
-    """At ~1 Hz, ask the policy for the next subtask.
+    """Deterministic plan walker — current subtask is the head of the plan.
 
-    Mirrors the ``high_level_subtask`` recipe layout exactly:
+    Rationale
+    ---------
+    The training-time ``plan_generation`` recipe summarises subtasks as
+    a numbered list of still-todo items. The ``high_level_subtask``
+    recipe supervises subtask text auto-regressively. Empirically the
+    LM head's AR generation collapses on small datasets and produces
+    repetitive / off-distribution subtasks
+    ("…extends and retracts and retracts…"), which corrupts the
+    action expert's conditioning.
 
-        user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
-        user: "Current subtask: ${subtask}" (if subtask present)
-        ↓ generate ↓
-        assistant: <subtask text>
+    Instead of trusting the LM head's free generation, walk the plan:
+
+    * The plan is a string ``"1. <subtask>\\n2. <subtask>\\n…"``.
+    * The current subtask is the line at ``state["plan_index"]``.
+    * After each chunk dispatch (action queue empties), advance the
+      index by one.
+
+    This is interpretable, can't repetition-collapse, and stays in
+    lockstep with the deterministic plan that ``plan_generation``
+    emits.
+
+    The LM head is still trained and still used for plan generation
+    (one-shot at episode start), VQA, and interjection replanning —
+    just not for the always-on per-chunk subtask loop.
     """
 
     policy: Any = None
-    observation_provider: Any = None
-    """Same shape as ``LowLevelForward.observation_provider``. When
-    set, the resulting observation is merged into ``select_message``'s
-    batch so text generation runs against real video + state."""
-
+    observation_provider: Any = None  # kept for signature compatibility
     trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
 
     def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
-        if self.policy is None or not state.get("task"):
+        if not state.get("task"):
             return None
 
-        # Gate to chunk boundaries: only generate a fresh subtask when
-        # the action queue is empty (i.e. right before LowLevelForward
-        # refreshes the chunk). ``select_message`` takes ~2 s on MPS,
-        # and running it every loop iteration starves DispatchAction
-        # at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
-        # of 30/sec and the robot barely moves. Tying it to the same
-        # "queue empty" condition as the chunk refresh produces a
-        # clean sense → think → act cycle.
+        # Only advance at chunk boundaries (action queue empty) — same
+        # gating as the original AR path, so subtask transitions stay
+        # aligned with the dispatch cycle.
         queue = state.get("action_queue") or []
         if len(queue) > 0:
             return None
 
-        ctx = _msgs_for_subtask(state)
-        observation = _maybe_observation(self.observation_provider)
-        # Default: greedy argmax, no min_new_tokens, no special-token
-        # suppression — matches training. Operator can override via
-        # ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
-        # on the CLI; useful for under-trained checkpoints whose LM
-        # head still favours EOS at position 0 (pre-trained chat
-        # backbone's short-turn prior hasn't been fully overridden
-        # by the fine-tuning supervision yet).
-        msg = _generate_with_policy(
-            self.policy,
-            ctx,
-            observation=observation,
-            state=state,
-            label="subtask gen",
-            min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
-            temperature=float(state.get("text_gen_temperature") or 0.0),
-            top_p=float(state.get("text_gen_top_p") or 1.0),
-            repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
-            no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
-        )
-        # Diagnostics: surface what the model is *actually* producing
-        # at chunk boundaries, even when the output gets rejected or
-        # repeats. Memorisation collapse looks like "same accepted
-        # subtask N times in a row" or "gibberish_count rising while
-        # current_subtask is stuck". The state panel renders these.
-        state["last_subtask_raw"] = msg or ""
-        # Persistent empty completion is its own failure mode (model
-        # immediately EOS-es from the chat-template generation
-        # prompt) — surface it once every N occurrences so the
-        # operator can distinguish "generation failing silently"
-        # from "generating fine but filter rejecting".
-        if not msg:
-            empties = state.get("subtask_empty_count", 0) + 1
-            state["subtask_empty_count"] = empties
-            if empties == 1 or empties % 5 == 0:
-                debug = getattr(self.policy, "_last_select_message_debug", "") or ""
-                if debug:
-                    push_log(
-                        state,
-                        f" [info] subtask gen empty (×{empties}); {debug}",
-                    )
-                else:
-                    push_log(
-                        state,
-                        f" [info] subtask gen returned empty (×{empties}) — "
-                        "no tokens generated (head EOS-ing before any "
-                        "non-special token).",
-                    )
-        if msg and _looks_like_gibberish(msg):
-            # Bump a counter so the operator can see the model is
-            # struggling without spamming the log every tick. A first
-            # rejection still logs once so the failure is visible.
-            count = state.get("subtask_gibberish_count", 0) + 1
-            state["subtask_gibberish_count"] = count
-            if count == 1 or count % 30 == 0:
-                push_log(
-                    state,
-                    f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
-                )
+
+        plan_lines = _parse_plan_lines(state.get("current_plan") or "")
+        if not plan_lines:
+            # No plan available yet (e.g. plan_generation hasn't fired
+            # on this episode). Fall back to the task string so the
+            # action expert at least has *something* to condition on.
+            fallback = state.get("task") or ""
+            if fallback and state.get("current_subtask") != fallback:
+                set_if_changed(state, "current_subtask", fallback, label="subtask")
             return None
-        if msg:
-            changed = set_if_changed(state, "current_subtask", msg, label="subtask")
-            if changed:
-                # Subtask change is a downstream trigger.
-                state.setdefault("events_this_tick", []).append("subtask_change")
-                state["subtask_repeat_count"] = 0
-            else:
-                # Same accepted string regenerated — memorisation tell.
-                # Once this counter climbs past a few, you're seeing
-                # the model unable to move past the current subtask
-                # despite the chunk having drained (visual scene may
-                # have changed but the LM is replaying training
-                # tokens).
-                state["subtask_repeat_count"] = (
-                    state.get("subtask_repeat_count", 0) + 1
-                )
-        # Silently skip empty completions — common when the model
-        # warms up or generates only EOS; logging it every tick at
-        # ctrl_hz is just noise.
+
+        # ``plan_index`` advances by 1 each chunk boundary. Tracked via
+        # ``actions_dispatched`` (set by ``DispatchAction``): we advance
+        # only when at least one new action was dispatched since the
+        # previous walker tick. Clamps to the last line so the robot
+        # keeps executing the final subtask instead of going silent.
+        plan_index = int(state.get("plan_index", 0) or 0)
+        dispatched = int(state.get("actions_dispatched", 0) or 0)
+        last_dispatched = state.get("_subtask_walker_last_dispatched")
+        if last_dispatched is not None and dispatched > int(last_dispatched):
+            plan_index = min(plan_index + 1, len(plan_lines) - 1)
+            state["plan_index"] = plan_index
+        state["_subtask_walker_last_dispatched"] = dispatched
+
+        new_subtask = plan_lines[plan_index]
+        changed = set_if_changed(state, "current_subtask", new_subtask, label="subtask")
+        if changed:
+            state.setdefault("events_this_tick", []).append("subtask_change")
+        # Surface diagnostics in the panel format the runtime expects.
+        state["last_subtask_raw"] = new_subtask
         return None
 
@@ -663,6 +622,28 @@ class DispatchToolCalls(InferenceStep):
 # ---------------------------------------------------------------------------
 
 
+_PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$")
+
+
+def _parse_plan_lines(plan_text: str) -> list[str]:
+    """Split a numbered-list plan string into its individual subtask lines.
+
+    Accepts both ``"1. foo"`` and ``"1) foo"`` formatting. Falls back to
+    splitting on newlines for non-numbered plans. Empty lines are
+    dropped. Returns the bare subtask text (number prefix stripped).
+    """
+    if not plan_text:
+        return []
+    out: list[str] = []
+    for raw in plan_text.splitlines():
+        line = raw.strip()
+        if not line:
+            continue
+        m = _PLAN_LINE_RE.match(line)
+        out.append(m.group(1).strip() if m else line)
+    return out
+
+
 def _looks_like_gibberish(text: str) -> bool:
     """Heuristically detect generation that's clearly off the rails.
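
Review note, not part of the patch: a self-contained sanity check of the plan
parser's accepted formats. The helper body is copied from the hunk above; the
sample plan strings are invented for illustration, not taken from a real
episode.

    # Standalone check of _parse_plan_lines against both numbering styles.
    import re

    _PLAN_LINE_RE = re.compile(r"^\s*\d+[.)]\s*(.+?)\s*$")


    def _parse_plan_lines(plan_text: str) -> list[str]:
        if not plan_text:
            return []
        out: list[str] = []
        for raw in plan_text.splitlines():
            line = raw.strip()
            if not line:
                continue
            m = _PLAN_LINE_RE.match(line)
            out.append(m.group(1).strip() if m else line)
        return out


    # "1. foo" and "1) foo" both parse; blank lines drop; prefixes strip.
    assert _parse_plan_lines("1. reach the cube\n2) grasp the cube\n\n3. lift") == [
        "reach the cube",
        "grasp the cube",
        "lift",
    ]
    # Non-numbered plans fall back to plain newline splitting.
    assert _parse_plan_lines("reach\ngrasp") == ["reach", "grasp"]
    assert _parse_plan_lines("") == []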
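A second sketch isolates the walker's advance rule from HighLevelSubtaskFwd.run
so the gating is easy to eyeball: the index moves one step per chunk boundary,
only after new actions were dispatched, and clamps on the final subtask.
walker_tick is a hypothetical name for this excerpt; the state keys match the
patch, but the tick sequence and plan are made up.

    state: dict = {"plan_index": 0, "actions_dispatched": 0}
    plan_lines = ["reach the cube", "grasp the cube", "lift the cube"]


    def walker_tick(state: dict, plan_lines: list[str]) -> str:
        # Same logic as the patched run(): advance only when new actions
        # were dispatched since the previous tick, clamp at the last line.
        plan_index = int(state.get("plan_index", 0) or 0)
        dispatched = int(state.get("actions_dispatched", 0) or 0)
        last = state.get("_subtask_walker_last_dispatched")
        if last is not None and dispatched > int(last):
            plan_index = min(plan_index + 1, len(plan_lines) - 1)
            state["plan_index"] = plan_index
        state["_subtask_walker_last_dispatched"] = dispatched
        return plan_lines[plan_index]


    assert walker_tick(state, plan_lines) == "reach the cube"  # first tick: no advance
    state["actions_dispatched"] = 50                           # a chunk was dispatched
    assert walker_tick(state, plan_lines) == "grasp the cube"  # advance by one
    assert walker_tick(state, plan_lines) == "grasp the cube"  # nothing new dispatched
    state["actions_dispatched"] = 100
    assert walker_tick(state, plan_lines) == "lift the cube"
    state["actions_dispatched"] = 150
    assert walker_tick(state, plan_lines) == "lift the cube"   # clamped at last line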