diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 015d03504..d30bc9dc9 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -245,7 +245,15 @@ def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]: @dataclass class HighLevelSubtaskFwd(InferenceStep): - """At ~1 Hz, ask the policy for the next subtask.""" + """At ~1 Hz, ask the policy for the next subtask. + + Mirrors the ``high_level_subtask`` recipe layout exactly: + + user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}" + user: "Current subtask: ${subtask}" (if subtask present) + ↓ generate ↓ + assistant: + """ policy: Any = None observation_provider: Any = None @@ -258,7 +266,7 @@ class HighLevelSubtaskFwd(InferenceStep): def run(self, state: dict[str, Any]) -> dict[str, Any] | None: if self.policy is None or not state.get("task"): return None - ctx = _control_context_messages(state) + ctx = _msgs_for_subtask(state) observation = _maybe_observation(self.observation_provider) msg = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="subtask gen" @@ -275,7 +283,16 @@ class HighLevelSubtaskFwd(InferenceStep): @dataclass class MemoryUpdateFwd(InferenceStep): - """On subtask boundary, refresh the compressed memory.""" + """On subtask boundary, refresh the compressed memory. + + Mirrors the ``memory_update`` recipe layout exactly: + + user: "${task}" + assistant: "Previous memory: ${prior_memory}" (if prior memory) + user: "Completed subtask: ${completed_subtask}" (if subtask) + ↓ generate ↓ + assistant: + """ policy: Any = None observation_provider: Any = None @@ -285,7 +302,7 @@ class MemoryUpdateFwd(InferenceStep): # Don't consume the event — multiple steps may want to react. if self.policy is None: return None - ctx = _control_context_messages(state, include_completed=True) + ctx = _msgs_for_memory(state) observation = _maybe_observation(self.observation_provider) new_memory = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="memory gen" @@ -297,7 +314,16 @@ class MemoryUpdateFwd(InferenceStep): @dataclass class UserInterjectionFwd(InferenceStep): - """On stdin interjection, refresh the plan + emit a paired ``say``.""" + """On stdin interjection, refresh the plan + emit a paired ``say``. + + Mirrors the ``user_interjection_response`` recipe layout exactly: + + user: "${task}" + assistant: "Previous plan:\\n${prior_plan}" (if prior plan) + user: "${interjection}" (the new utterance) + ↓ generate ↓ + assistant: ...> + """ policy: Any = None observation_provider: Any = None @@ -306,10 +332,7 @@ class UserInterjectionFwd(InferenceStep): def run(self, state: dict[str, Any]) -> dict[str, Any] | None: if self.policy is None or not take_event(state, "user_interjection"): return None - ctx = _control_context_messages( - state, - extra_user=state.get("recent_interjection"), - ) + ctx = _msgs_for_interjection(state) observation = _maybe_observation(self.observation_provider) out = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="plan/say gen" @@ -340,7 +363,17 @@ class UserInterjectionFwd(InferenceStep): @dataclass class AskVQAFwd(InferenceStep): - """On stdin question, answer a frame-grounded VQA.""" + """On stdin question, answer a frame-grounded VQA. + + Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user + turn carrying just the VQA question, plus the camera image block + in training (we drop the image at inference because the dataset's + image preprocessing doesn't match SmolVLM's vision tower input). + + user: + ↓ generate ↓ + assistant: + """ policy: Any = None observation_provider: Any = None @@ -352,7 +385,7 @@ class AskVQAFwd(InferenceStep): question = state.get("recent_vqa_query") if not question: return None - ctx = _control_context_messages(state, extra_user=question) + ctx = _msgs_for_vqa(question) observation = _maybe_observation(self.observation_provider) answer = _generate_with_policy( self.policy, ctx, observation=observation, state=state, label="vqa gen" @@ -426,6 +459,74 @@ def _control_context_messages( return msgs +# --------------------------------------------------------------------------- +# Per-recipe prompt builders. Each one mirrors a single sub-recipe's +# message layout in ``smolvla2_hirobot.yaml`` so the chat-templated +# prompt at inference matches what the model saw during training. +# Generic ``_control_context_messages`` is kept around as a fallback +# for ad-hoc callers but the four high-level steps now use these. +# --------------------------------------------------------------------------- + + +def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]: + """``high_level_subtask`` recipe layout.""" + head_parts = [state.get("task") or ""] + if state.get("current_plan"): + head_parts.append(f"Plan: {state['current_plan']}") + if state.get("current_memory"): + head_parts.append(f"Memory: {state['current_memory']}") + msgs: list[dict[str, Any]] = [ + {"role": "user", "content": "\n".join(head_parts)} + ] + if state.get("current_subtask"): + msgs.append( + {"role": "user", "content": f"Current subtask: {state['current_subtask']}"} + ) + return msgs + + +def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]: + """``memory_update`` recipe layout.""" + msgs: list[dict[str, Any]] = [ + {"role": "user", "content": state.get("task") or ""} + ] + if state.get("current_memory"): + msgs.append( + { + "role": "assistant", + "content": f"Previous memory: {state['current_memory']}", + } + ) + if state.get("current_subtask"): + msgs.append( + { + "role": "user", + "content": f"Completed subtask: {state['current_subtask']}", + } + ) + return msgs + + +def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]: + """``user_interjection_response`` recipe layout.""" + msgs: list[dict[str, Any]] = [ + {"role": "user", "content": state.get("task") or ""} + ] + if state.get("current_plan"): + msgs.append( + {"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"} + ) + interjection = state.get("recent_interjection") + if interjection: + msgs.append({"role": "user", "content": interjection}) + return msgs + + +def _msgs_for_vqa(question: str) -> list[dict[str, Any]]: + """``ask_vqa_*`` recipe layout (text-only at inference).""" + return [{"role": "user", "content": question}] + + def _maybe_observation(provider: Any) -> dict | None: """Pull one observation from ``provider`` if it's set, else ``None``.