fix(smolvla2): tokenize lang prompt inline before select_action

LowLevelForward was handing the observation provider's output straight
to ``policy.select_action``, but SmolVLA's ``_get_action_chunk``
indexes ``batch[OBS_LANGUAGE_TOKENS]`` and crashes with ``KeyError:
'observation.language.tokens'`` when the key isn't there. Our provider
deliberately strips the dataset's language columns (the runtime drives
messages itself), so nothing else was producing those tokens — the
chunk path crashed on the very first tick after task was set.

Build a low-level prompt from current runtime state inline (task /
plan / memory as the user turn, current subtask appended as a
continuation assistant turn when known), tokenize it with the same
helper the high-level steps use, and merge ``lang_tokens`` /
``lang_masks`` into the observation before the call. Skip the step
when no task is set yet, and swallow ``select_action`` exceptions at
debug level so a missing observation feature doesn't kill the REPL.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 11:40:18 +02:00
parent fea41b29f5
commit 9cbbcfb6a2
@@ -91,10 +91,35 @@ class LowLevelForward(InferenceStep):
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or self.observation_provider is None:
return None
if not state.get("task"):
# No task yet → nothing useful to condition on.
return None
observation = self.observation_provider()
if observation is None:
return None
action = self.policy.select_action(observation)
# SmolVLA's ``select_action`` expects the full preprocessed
# batch, including ``OBS_LANGUAGE_TOKENS`` /
# ``OBS_LANGUAGE_ATTENTION_MASK``. The observation provider
# only returns image / state features (the runtime drives
# messages itself), so build a low-level prompt from current
# runtime state and tokenize it inline.
ctx = _control_context_messages(state)
if state.get("current_subtask"):
ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}]
text_batch = _build_text_batch(self.policy, ctx)
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
observation = dict(observation)
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
try:
action = self.policy.select_action(observation)
except Exception as exc: # noqa: BLE001
logger.debug("select_action skipped: %s", exc)
return None
# SmolVLA returns a single action; if the underlying policy
# streams chunks, split per-step here. For v1 we just enqueue
# the result.