mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix(smolvla2): per-recipe inference prompts to match training shape
The four high-level steps shared one generic ``_control_context_messages`` that jammed task + plan + memory + completed_subtask into a single user message. The recipes in ``smolvla2_hirobot.yaml`` each have a *specific* multi-message layout (``memory_update``: ``user(task) → assistant(prev memory) → user(completed subtask)``; ``high_level_subtask``: ``user(task+plan+ memory) → user(current subtask)``; ``user_interjection_response``: ``user(task) → assistant(prev plan) → user(interjection)``). After ``apply_chat_template`` those layouts produce different prompts than the runtime's flattened single-user-turn version, and the model fell back to its dominant training mode (VQA JSON output) — generating ``":":":":":":...`` repetition. Add four per-recipe prompt builders (``_msgs_for_subtask``, ``_msgs_for_memory``, ``_msgs_for_interjection``, ``_msgs_for_vqa``), each mirroring its sub-recipe's exact message structure including the ``if_present`` skips. Wire each high-level step to its matching builder. Inference prompts now line up with what the model saw in training, so generation should produce coherent text instead of repeated tokens. Generic ``_control_context_messages`` is kept (still used by tests and the no-recipe fallback path). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -245,7 +245,15 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@dataclass
|
||||
class HighLevelSubtaskFwd(InferenceStep):
|
||||
"""At ~1 Hz, ask the policy for the next subtask."""
|
||||
"""At ~1 Hz, ask the policy for the next subtask.
|
||||
|
||||
Mirrors the ``high_level_subtask`` recipe layout exactly:
|
||||
|
||||
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||
user: "Current subtask: ${subtask}" (if subtask present)
|
||||
↓ generate ↓
|
||||
assistant: <next subtask>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
@@ -258,7 +266,7 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not state.get("task"):
|
||||
return None
|
||||
ctx = _control_context_messages(state)
|
||||
ctx = _msgs_for_subtask(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
msg = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
||||
@@ -275,7 +283,16 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
|
||||
@dataclass
|
||||
class MemoryUpdateFwd(InferenceStep):
|
||||
"""On subtask boundary, refresh the compressed memory."""
|
||||
"""On subtask boundary, refresh the compressed memory.
|
||||
|
||||
Mirrors the ``memory_update`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if prior memory)
|
||||
user: "Completed subtask: ${completed_subtask}" (if subtask)
|
||||
↓ generate ↓
|
||||
assistant: <new memory>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
@@ -285,7 +302,7 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
# Don't consume the event — multiple steps may want to react.
|
||||
if self.policy is None:
|
||||
return None
|
||||
ctx = _control_context_messages(state, include_completed=True)
|
||||
ctx = _msgs_for_memory(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
||||
@@ -297,7 +314,16 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
|
||||
@dataclass
|
||||
class UserInterjectionFwd(InferenceStep):
|
||||
"""On stdin interjection, refresh the plan + emit a paired ``say``."""
|
||||
"""On stdin interjection, refresh the plan + emit a paired ``say``.
|
||||
|
||||
Mirrors the ``user_interjection_response`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
|
||||
user: "${interjection}" (the new utterance)
|
||||
↓ generate ↓
|
||||
assistant: <plan + <say>...</say>>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
@@ -306,10 +332,7 @@ class UserInterjectionFwd(InferenceStep):
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_interjection"):
|
||||
return None
|
||||
ctx = _control_context_messages(
|
||||
state,
|
||||
extra_user=state.get("recent_interjection"),
|
||||
)
|
||||
ctx = _msgs_for_interjection(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="plan/say gen"
|
||||
@@ -340,7 +363,17 @@ class UserInterjectionFwd(InferenceStep):
|
||||
|
||||
@dataclass
|
||||
class AskVQAFwd(InferenceStep):
|
||||
"""On stdin question, answer a frame-grounded VQA."""
|
||||
"""On stdin question, answer a frame-grounded VQA.
|
||||
|
||||
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
|
||||
turn carrying just the VQA question, plus the camera image block
|
||||
in training (we drop the image at inference because the dataset's
|
||||
image preprocessing doesn't match SmolVLM's vision tower input).
|
||||
|
||||
user: <question>
|
||||
↓ generate ↓
|
||||
assistant: <vqa answer>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
@@ -352,7 +385,7 @@ class AskVQAFwd(InferenceStep):
|
||||
question = state.get("recent_vqa_query")
|
||||
if not question:
|
||||
return None
|
||||
ctx = _control_context_messages(state, extra_user=question)
|
||||
ctx = _msgs_for_vqa(question)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
||||
@@ -426,6 +459,74 @@ def _control_context_messages(
|
||||
return msgs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
|
||||
# message layout in ``smolvla2_hirobot.yaml`` so the chat-templated
|
||||
# prompt at inference matches what the model saw during training.
|
||||
# Generic ``_control_context_messages`` is kept around as a fallback
|
||||
# for ad-hoc callers but the four high-level steps now use these.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``high_level_subtask`` recipe layout."""
|
||||
head_parts = [state.get("task") or ""]
|
||||
if state.get("current_plan"):
|
||||
head_parts.append(f"Plan: {state['current_plan']}")
|
||||
if state.get("current_memory"):
|
||||
head_parts.append(f"Memory: {state['current_memory']}")
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "\n".join(head_parts)}
|
||||
]
|
||||
if state.get("current_subtask"):
|
||||
msgs.append(
|
||||
{"role": "user", "content": f"Current subtask: {state['current_subtask']}"}
|
||||
)
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``memory_update`` recipe layout."""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""}
|
||||
]
|
||||
if state.get("current_memory"):
|
||||
msgs.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Previous memory: {state['current_memory']}",
|
||||
}
|
||||
)
|
||||
if state.get("current_subtask"):
|
||||
msgs.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Completed subtask: {state['current_subtask']}",
|
||||
}
|
||||
)
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``user_interjection_response`` recipe layout."""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""}
|
||||
]
|
||||
if state.get("current_plan"):
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
|
||||
)
|
||||
interjection = state.get("recent_interjection")
|
||||
if interjection:
|
||||
msgs.append({"role": "user", "content": interjection})
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
|
||||
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
|
||||
return [{"role": "user", "content": question}]
|
||||
|
||||
|
||||
def _maybe_observation(provider: Any) -> dict | None:
|
||||
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user