mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
fix(smolvla2): per-recipe inference prompts to match training shape
The four high-level steps shared one generic ``_control_context_messages`` that jammed task + plan + memory + completed_subtask into a single user message. The recipes in ``smolvla2_hirobot.yaml`` each have a *specific* multi-message layout (``memory_update``: ``user(task) → assistant(prev memory) → user(completed subtask)``; ``high_level_subtask``: ``user(task+plan+memory) → user(current subtask)``; ``user_interjection_response``: ``user(task) → assistant(prev plan) → user(interjection)``). After ``apply_chat_template`` those layouts produce different prompts than the runtime's flattened single-user-turn version, and the model fell back to its dominant training mode (VQA JSON output) — generating ``":":":":":":...`` repetition. Add four per-recipe prompt builders (``_msgs_for_subtask``, ``_msgs_for_memory``, ``_msgs_for_interjection``, ``_msgs_for_vqa``), each mirroring its sub-recipe's exact message structure including the ``if_present`` skips. Wire each high-level step to its matching builder. Inference prompts now line up with what the model saw in training, so generation should produce coherent text instead of repeated tokens. The generic ``_control_context_messages`` is kept (still used by tests and the no-recipe fallback path). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -245,7 +245,15 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HighLevelSubtaskFwd(InferenceStep):
|
class HighLevelSubtaskFwd(InferenceStep):
|
||||||
"""At ~1 Hz, ask the policy for the next subtask."""
|
"""At ~1 Hz, ask the policy for the next subtask.
|
||||||
|
|
||||||
|
Mirrors the ``high_level_subtask`` recipe layout exactly:
|
||||||
|
|
||||||
|
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||||
|
user: "Current subtask: ${subtask}" (if subtask present)
|
||||||
|
↓ generate ↓
|
||||||
|
assistant: <next subtask>
|
||||||
|
"""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
observation_provider: Any = None
|
observation_provider: Any = None
|
||||||
@@ -258,7 +266,7 @@ class HighLevelSubtaskFwd(InferenceStep):
|
|||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
if self.policy is None or not state.get("task"):
|
if self.policy is None or not state.get("task"):
|
||||||
return None
|
return None
|
||||||
ctx = _control_context_messages(state)
|
ctx = _msgs_for_subtask(state)
|
||||||
observation = _maybe_observation(self.observation_provider)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
msg = _generate_with_policy(
|
msg = _generate_with_policy(
|
||||||
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
||||||
@@ -275,7 +283,16 @@ class HighLevelSubtaskFwd(InferenceStep):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryUpdateFwd(InferenceStep):
|
class MemoryUpdateFwd(InferenceStep):
|
||||||
"""On subtask boundary, refresh the compressed memory."""
|
"""On subtask boundary, refresh the compressed memory.
|
||||||
|
|
||||||
|
Mirrors the ``memory_update`` recipe layout exactly:
|
||||||
|
|
||||||
|
user: "${task}"
|
||||||
|
assistant: "Previous memory: ${prior_memory}" (if prior memory)
|
||||||
|
user: "Completed subtask: ${completed_subtask}" (if subtask)
|
||||||
|
↓ generate ↓
|
||||||
|
assistant: <new memory>
|
||||||
|
"""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
observation_provider: Any = None
|
observation_provider: Any = None
|
||||||
@@ -285,7 +302,7 @@ class MemoryUpdateFwd(InferenceStep):
|
|||||||
# Don't consume the event — multiple steps may want to react.
|
# Don't consume the event — multiple steps may want to react.
|
||||||
if self.policy is None:
|
if self.policy is None:
|
||||||
return None
|
return None
|
||||||
ctx = _control_context_messages(state, include_completed=True)
|
ctx = _msgs_for_memory(state)
|
||||||
observation = _maybe_observation(self.observation_provider)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
new_memory = _generate_with_policy(
|
new_memory = _generate_with_policy(
|
||||||
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
||||||
@@ -297,7 +314,16 @@ class MemoryUpdateFwd(InferenceStep):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UserInterjectionFwd(InferenceStep):
|
class UserInterjectionFwd(InferenceStep):
|
||||||
"""On stdin interjection, refresh the plan + emit a paired ``say``."""
|
"""On stdin interjection, refresh the plan + emit a paired ``say``.
|
||||||
|
|
||||||
|
Mirrors the ``user_interjection_response`` recipe layout exactly:
|
||||||
|
|
||||||
|
user: "${task}"
|
||||||
|
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
|
||||||
|
user: "${interjection}" (the new utterance)
|
||||||
|
↓ generate ↓
|
||||||
|
assistant: <plan + <say>...</say>>
|
||||||
|
"""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
observation_provider: Any = None
|
observation_provider: Any = None
|
||||||
@@ -306,10 +332,7 @@ class UserInterjectionFwd(InferenceStep):
|
|||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
if self.policy is None or not take_event(state, "user_interjection"):
|
if self.policy is None or not take_event(state, "user_interjection"):
|
||||||
return None
|
return None
|
||||||
ctx = _control_context_messages(
|
ctx = _msgs_for_interjection(state)
|
||||||
state,
|
|
||||||
extra_user=state.get("recent_interjection"),
|
|
||||||
)
|
|
||||||
observation = _maybe_observation(self.observation_provider)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
out = _generate_with_policy(
|
out = _generate_with_policy(
|
||||||
self.policy, ctx, observation=observation, state=state, label="plan/say gen"
|
self.policy, ctx, observation=observation, state=state, label="plan/say gen"
|
||||||
@@ -340,7 +363,17 @@ class UserInterjectionFwd(InferenceStep):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AskVQAFwd(InferenceStep):
|
class AskVQAFwd(InferenceStep):
|
||||||
"""On stdin question, answer a frame-grounded VQA."""
|
"""On stdin question, answer a frame-grounded VQA.
|
||||||
|
|
||||||
|
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
|
||||||
|
turn carrying just the VQA question, plus the camera image block
|
||||||
|
in training (we drop the image at inference because the dataset's
|
||||||
|
image preprocessing doesn't match SmolVLM's vision tower input).
|
||||||
|
|
||||||
|
user: <question>
|
||||||
|
↓ generate ↓
|
||||||
|
assistant: <vqa answer>
|
||||||
|
"""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
observation_provider: Any = None
|
observation_provider: Any = None
|
||||||
@@ -352,7 +385,7 @@ class AskVQAFwd(InferenceStep):
|
|||||||
question = state.get("recent_vqa_query")
|
question = state.get("recent_vqa_query")
|
||||||
if not question:
|
if not question:
|
||||||
return None
|
return None
|
||||||
ctx = _control_context_messages(state, extra_user=question)
|
ctx = _msgs_for_vqa(question)
|
||||||
observation = _maybe_observation(self.observation_provider)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
answer = _generate_with_policy(
|
answer = _generate_with_policy(
|
||||||
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
||||||
@@ -426,6 +459,74 @@ def _control_context_messages(
|
|||||||
return msgs
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
|
||||||
|
# message layout in ``smolvla2_hirobot.yaml`` so the chat-templated
|
||||||
|
# prompt at inference matches what the model saw during training.
|
||||||
|
# Generic ``_control_context_messages`` is kept around as a fallback
|
||||||
|
# for ad-hoc callers but the four high-level steps now use these.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
"""``high_level_subtask`` recipe layout."""
|
||||||
|
head_parts = [state.get("task") or ""]
|
||||||
|
if state.get("current_plan"):
|
||||||
|
head_parts.append(f"Plan: {state['current_plan']}")
|
||||||
|
if state.get("current_memory"):
|
||||||
|
head_parts.append(f"Memory: {state['current_memory']}")
|
||||||
|
msgs: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "\n".join(head_parts)}
|
||||||
|
]
|
||||||
|
if state.get("current_subtask"):
|
||||||
|
msgs.append(
|
||||||
|
{"role": "user", "content": f"Current subtask: {state['current_subtask']}"}
|
||||||
|
)
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
|
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
"""``memory_update`` recipe layout."""
|
||||||
|
msgs: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": state.get("task") or ""}
|
||||||
|
]
|
||||||
|
if state.get("current_memory"):
|
||||||
|
msgs.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": f"Previous memory: {state['current_memory']}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if state.get("current_subtask"):
|
||||||
|
msgs.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"Completed subtask: {state['current_subtask']}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
|
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
"""``user_interjection_response`` recipe layout."""
|
||||||
|
msgs: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": state.get("task") or ""}
|
||||||
|
]
|
||||||
|
if state.get("current_plan"):
|
||||||
|
msgs.append(
|
||||||
|
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
|
||||||
|
)
|
||||||
|
interjection = state.get("recent_interjection")
|
||||||
|
if interjection:
|
||||||
|
msgs.append({"role": "user", "content": interjection})
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
|
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
|
||||||
|
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
|
||||||
|
return [{"role": "user", "content": question}]
|
||||||
|
|
||||||
|
|
||||||
def _maybe_observation(provider: Any) -> dict | None:
|
def _maybe_observation(provider: Any) -> dict | None:
|
||||||
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user