fix(smolvla2): per-recipe inference prompts to match training shape

The four high-level steps shared one generic
``_control_context_messages`` that jammed task + plan + memory +
completed_subtask into a single user message. The recipes in
``smolvla2_hirobot.yaml`` each have a *specific* multi-message layout
(``memory_update``: ``user(task) → assistant(prev memory) →
user(completed subtask)``; ``high_level_subtask``: ``user(task+plan+
memory) → user(current subtask)``; ``user_interjection_response``:
``user(task) → assistant(prev plan) → user(interjection)``). After
``apply_chat_template`` those layouts produce different prompts than
the runtime's flattened single-user-turn version, and the model fell
back to its dominant training mode (VQA JSON output) — generating
``":":":":":":...`` repetition.

Add four per-recipe prompt builders (``_msgs_for_subtask``,
``_msgs_for_memory``, ``_msgs_for_interjection``, ``_msgs_for_vqa``),
each mirroring its sub-recipe's exact message structure including
the ``if_present`` skips. Wire each high-level step to its matching
builder. Inference prompts now line up with what the model saw in
training, so generation should produce coherent text instead of
repeated tokens.

The generic ``_control_context_messages`` is kept (it is still used by
tests and by the no-recipe fallback path).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 13:47:22 +02:00
parent 6d9b431b54
commit a47e535b02
+112 -11
View File
@@ -245,7 +245,15 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
@dataclass @dataclass
class HighLevelSubtaskFwd(InferenceStep): class HighLevelSubtaskFwd(InferenceStep):
"""At ~1 Hz, ask the policy for the next subtask.""" """At ~1 Hz, ask the policy for the next subtask.
Mirrors the ``high_level_subtask`` recipe layout exactly:
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
user: "Current subtask: ${subtask}" (if subtask present)
↓ generate ↓
assistant: <next subtask>
"""
policy: Any = None policy: Any = None
observation_provider: Any = None observation_provider: Any = None
@@ -258,7 +266,7 @@ class HighLevelSubtaskFwd(InferenceStep):
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not state.get("task"): if self.policy is None or not state.get("task"):
return None return None
ctx = _control_context_messages(state) ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider) observation = _maybe_observation(self.observation_provider)
msg = _generate_with_policy( msg = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="subtask gen" self.policy, ctx, observation=observation, state=state, label="subtask gen"
@@ -275,7 +283,16 @@ class HighLevelSubtaskFwd(InferenceStep):
@dataclass @dataclass
class MemoryUpdateFwd(InferenceStep): class MemoryUpdateFwd(InferenceStep):
"""On subtask boundary, refresh the compressed memory.""" """On subtask boundary, refresh the compressed memory.
Mirrors the ``memory_update`` recipe layout exactly:
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if prior memory)
user: "Completed subtask: ${completed_subtask}" (if subtask)
↓ generate ↓
assistant: <new memory>
"""
policy: Any = None policy: Any = None
observation_provider: Any = None observation_provider: Any = None
@@ -285,7 +302,7 @@ class MemoryUpdateFwd(InferenceStep):
# Don't consume the event — multiple steps may want to react. # Don't consume the event — multiple steps may want to react.
if self.policy is None: if self.policy is None:
return None return None
ctx = _control_context_messages(state, include_completed=True) ctx = _msgs_for_memory(state)
observation = _maybe_observation(self.observation_provider) observation = _maybe_observation(self.observation_provider)
new_memory = _generate_with_policy( new_memory = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="memory gen" self.policy, ctx, observation=observation, state=state, label="memory gen"
@@ -297,7 +314,16 @@ class MemoryUpdateFwd(InferenceStep):
@dataclass @dataclass
class UserInterjectionFwd(InferenceStep): class UserInterjectionFwd(InferenceStep):
"""On stdin interjection, refresh the plan + emit a paired ``say``.""" """On stdin interjection, refresh the plan + emit a paired ``say``.
Mirrors the ``user_interjection_response`` recipe layout exactly:
user: "${task}"
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
user: "${interjection}" (the new utterance)
↓ generate ↓
assistant: <plan + <say>...</say>>
"""
policy: Any = None policy: Any = None
observation_provider: Any = None observation_provider: Any = None
@@ -306,10 +332,7 @@ class UserInterjectionFwd(InferenceStep):
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not take_event(state, "user_interjection"): if self.policy is None or not take_event(state, "user_interjection"):
return None return None
ctx = _control_context_messages( ctx = _msgs_for_interjection(state)
state,
extra_user=state.get("recent_interjection"),
)
observation = _maybe_observation(self.observation_provider) observation = _maybe_observation(self.observation_provider)
out = _generate_with_policy( out = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="plan/say gen" self.policy, ctx, observation=observation, state=state, label="plan/say gen"
@@ -340,7 +363,17 @@ class UserInterjectionFwd(InferenceStep):
@dataclass @dataclass
class AskVQAFwd(InferenceStep): class AskVQAFwd(InferenceStep):
"""On stdin question, answer a frame-grounded VQA.""" """On stdin question, answer a frame-grounded VQA.
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
turn carrying just the VQA question, plus the camera image block
in training (we drop the image at inference because the dataset's
image preprocessing doesn't match SmolVLM's vision tower input).
user: <question>
↓ generate ↓
assistant: <vqa answer>
"""
policy: Any = None policy: Any = None
observation_provider: Any = None observation_provider: Any = None
@@ -352,7 +385,7 @@ class AskVQAFwd(InferenceStep):
question = state.get("recent_vqa_query") question = state.get("recent_vqa_query")
if not question: if not question:
return None return None
ctx = _control_context_messages(state, extra_user=question) ctx = _msgs_for_vqa(question)
observation = _maybe_observation(self.observation_provider) observation = _maybe_observation(self.observation_provider)
answer = _generate_with_policy( answer = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="vqa gen" self.policy, ctx, observation=observation, state=state, label="vqa gen"
@@ -426,6 +459,74 @@ def _control_context_messages(
return msgs return msgs
# ---------------------------------------------------------------------------
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
# message layout in ``smolvla2_hirobot.yaml`` so the chat-templated
# prompt at inference matches what the model saw during training.
# Generic ``_control_context_messages`` is kept around as a fallback
# for ad-hoc callers but the four high-level steps now use these.
# ---------------------------------------------------------------------------
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``high_level_subtask`` recipe layout."""
head_parts = [state.get("task") or ""]
if state.get("current_plan"):
head_parts.append(f"Plan: {state['current_plan']}")
if state.get("current_memory"):
head_parts.append(f"Memory: {state['current_memory']}")
msgs: list[dict[str, Any]] = [
{"role": "user", "content": "\n".join(head_parts)}
]
if state.get("current_subtask"):
msgs.append(
{"role": "user", "content": f"Current subtask: {state['current_subtask']}"}
)
return msgs
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``memory_update`` recipe layout."""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""}
]
if state.get("current_memory"):
msgs.append(
{
"role": "assistant",
"content": f"Previous memory: {state['current_memory']}",
}
)
if state.get("current_subtask"):
msgs.append(
{
"role": "user",
"content": f"Completed subtask: {state['current_subtask']}",
}
)
return msgs
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``user_interjection_response`` recipe layout."""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""}
]
if state.get("current_plan"):
msgs.append(
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
)
interjection = state.get("recent_interjection")
if interjection:
msgs.append({"role": "user", "content": interjection})
return msgs
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
return [{"role": "user", "content": question}]
def _maybe_observation(provider: Any) -> dict | None: def _maybe_observation(provider: Any) -> dict | None:
"""Pull one observation from ``provider`` if it's set, else ``None``. """Pull one observation from ``provider`` if it's set, else ``None``.