From a47e535b026dd2af448be4900e38f7972c32009a Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 5 May 2026 13:47:22 +0200 Subject: [PATCH] fix(smolvla2): per-recipe inference prompts to match training shape MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The four high-level steps shared one generic ``_control_context_messages`` that jammed task + plan + memory + completed_subtask into a single user message. The recipes in ``smolvla2_hirobot.yaml`` each have a *specific* multi-message layout (``memory_update``: ``user(task) → assistant(prev memory) → user(completed subtask)``; ``high_level_subtask``: ``user(task+plan+memory) → user(current subtask)``; ``user_interjection_response``: ``user(task) → assistant(prev plan) → user(interjection)``). After ``apply_chat_template`` those layouts produce different prompts than the runtime's flattened single-user-turn version, and the model fell back to its dominant training mode (VQA JSON output) — generating ``":":":":":":...`` repetition. Add four per-recipe prompt builders (``_msgs_for_subtask``, ``_msgs_for_memory``, ``_msgs_for_interjection``, ``_msgs_for_vqa``), each mirroring its sub-recipe's exact message structure including the ``if_present`` skips. Wire each high-level step to its matching builder. Inference prompts now line up with what the model saw in training, so generation should produce coherent text instead of repeated tokens. The generic ``_control_context_messages`` is kept (still used by tests and the no-recipe fallback path).
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/inference/steps.py | 123 ++++++++++++++++-- 1 file changed, 112 insertions(+), 11 deletions(-) diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 015d03504..d30bc9dc9 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -245,7 +245,15 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]: @dataclass class HighLevelSubtaskFwd(InferenceStep): - """At ~1 Hz, ask the policy for the next subtask.""" + """At ~1 Hz, ask the policy for the next subtask. + + Mirrors the ``high_level_subtask`` recipe layout exactly: + + user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}" + user: "Current subtask: ${subtask}" (if subtask present) + ↓ generate ↓ + assistant: + """ policy: Any = None observation_provider: Any = None @@ -258,7 +266,7 @@ class HighLevelSubtaskFwd(InferenceStep): def run(self, state: dict[str, Any]) -> dict[str, Any] | None: if self.policy is None or not state.get("task"): return None - ctx = _control_context_messages(state) + ctx = _msgs_for_subtask(state) observation = _maybe_observation(self.observation_provider) msg = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="subtask gen" @@ -275,7 +283,16 @@ class HighLevelSubtaskFwd(InferenceStep): @dataclass class MemoryUpdateFwd(InferenceStep): - """On subtask boundary, refresh the compressed memory.""" + """On subtask boundary, refresh the compressed memory. + + Mirrors the ``memory_update`` recipe layout exactly: + + user: "${task}" + assistant: "Previous memory: ${prior_memory}" (if prior memory) + user: "Completed subtask: ${completed_subtask}" (if subtask) + ↓ generate ↓ + assistant: + """ policy: Any = None observation_provider: Any = None @@ -285,7 +302,7 @@ class MemoryUpdateFwd(InferenceStep): # Don't consume the event — multiple steps may want to react. 
if self.policy is None: return None - ctx = _control_context_messages(state, include_completed=True) + ctx = _msgs_for_memory(state) observation = _maybe_observation(self.observation_provider) new_memory = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="memory gen" @@ -297,7 +314,16 @@ class MemoryUpdateFwd(InferenceStep): @dataclass class UserInterjectionFwd(InferenceStep): - """On stdin interjection, refresh the plan + emit a paired ``say``.""" + """On stdin interjection, refresh the plan + emit a paired ``say``. + + Mirrors the ``user_interjection_response`` recipe layout exactly: + + user: "${task}" + assistant: "Previous plan:\\n${prior_plan}" (if prior plan) + user: "${interjection}" (the new utterance) + ↓ generate ↓ + assistant: ...> + """ policy: Any = None observation_provider: Any = None @@ -306,10 +332,7 @@ class UserInterjectionFwd(InferenceStep): def run(self, state: dict[str, Any]) -> dict[str, Any] | None: if self.policy is None or not take_event(state, "user_interjection"): return None - ctx = _control_context_messages( - state, - extra_user=state.get("recent_interjection"), - ) + ctx = _msgs_for_interjection(state) observation = _maybe_observation(self.observation_provider) out = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="plan/say gen" @@ -340,7 +363,17 @@ class UserInterjectionFwd(InferenceStep): @dataclass class AskVQAFwd(InferenceStep): - """On stdin question, answer a frame-grounded VQA.""" + """On stdin question, answer a frame-grounded VQA. + + Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user + turn carrying just the VQA question, plus the camera image block + in training (we drop the image at inference because the dataset's + image preprocessing doesn't match SmolVLM's vision tower input). 
+ + user: + ↓ generate ↓ + assistant: + """ policy: Any = None observation_provider: Any = None @@ -352,7 +385,7 @@ class AskVQAFwd(InferenceStep): question = state.get("recent_vqa_query") if not question: return None - ctx = _control_context_messages(state, extra_user=question) + ctx = _msgs_for_vqa(question) observation = _maybe_observation(self.observation_provider) answer = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="vqa gen" @@ -426,6 +459,74 @@ def _control_context_messages( return msgs +# --------------------------------------------------------------------------- +# Per-recipe prompt builders. Each one mirrors a single sub-recipe's +# message layout in ``smolvla2_hirobot.yaml`` so the chat-templated +# prompt at inference matches what the model saw during training. +# Generic ``_control_context_messages`` is kept around as a fallback +# for ad-hoc callers but the four high-level steps now use these. +# --------------------------------------------------------------------------- + + +def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]: + """``high_level_subtask`` recipe layout.""" + head_parts = [state.get("task") or ""] + if state.get("current_plan"): + head_parts.append(f"Plan: {state['current_plan']}") + if state.get("current_memory"): + head_parts.append(f"Memory: {state['current_memory']}") + msgs: list[dict[str, Any]] = [ + {"role": "user", "content": "\n".join(head_parts)} + ] + if state.get("current_subtask"): + msgs.append( + {"role": "user", "content": f"Current subtask: {state['current_subtask']}"} + ) + return msgs + + +def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]: + """``memory_update`` recipe layout.""" + msgs: list[dict[str, Any]] = [ + {"role": "user", "content": state.get("task") or ""} + ] + if state.get("current_memory"): + msgs.append( + { + "role": "assistant", + "content": f"Previous memory: {state['current_memory']}", + } + ) + if 
state.get("current_subtask"): + msgs.append( + { + "role": "user", + "content": f"Completed subtask: {state['current_subtask']}", + } + ) + return msgs + + +def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]: + """``user_interjection_response`` recipe layout.""" + msgs: list[dict[str, Any]] = [ + {"role": "user", "content": state.get("task") or ""} + ] + if state.get("current_plan"): + msgs.append( + {"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"} + ) + interjection = state.get("recent_interjection") + if interjection: + msgs.append({"role": "user", "content": interjection}) + return msgs + + +def _msgs_for_vqa(question: str) -> list[dict[str, Any]]: + """``ask_vqa_*`` recipe layout (text-only at inference).""" + return [{"role": "user", "content": question}] + + def _maybe_observation(provider: Any) -> dict | None: """Pull one observation from ``provider`` if it's set, else ``None``.