mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
fix(smolvla2): tokenize lang prompt inline before select_action
LowLevelForward was handing the observation provider's output straight to ``policy.select_action``, but SmolVLA's ``_get_action_chunk`` indexes ``batch[OBS_LANGUAGE_TOKENS]`` and crashes with ``KeyError: 'observation.language.tokens'`` when the key isn't there. Our provider deliberately strips the dataset's language columns (the runtime drives messages itself), so nothing else was producing those tokens — the chunk path crashed on the very first tick after task was set. Build a low-level prompt from current runtime state inline (task / plan / memory as the user turn, current subtask appended as a continuation assistant turn when known), tokenize it with the same helper the high-level steps use, and merge ``lang_tokens`` / ``lang_masks`` into the observation before the call. Skip the step when no task is set yet, and swallow ``select_action`` exceptions at debug level so a missing observation feature doesn't kill the REPL. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -91,10 +91,35 @@ class LowLevelForward(InferenceStep):
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or self.observation_provider is None:
|
||||
return None
|
||||
if not state.get("task"):
|
||||
# No task yet → nothing useful to condition on.
|
||||
return None
|
||||
observation = self.observation_provider()
|
||||
if observation is None:
|
||||
return None
|
||||
action = self.policy.select_action(observation)
|
||||
# SmolVLA's ``select_action`` expects the full preprocessed
|
||||
# batch, including ``OBS_LANGUAGE_TOKENS`` /
|
||||
# ``OBS_LANGUAGE_ATTENTION_MASK``. The observation provider
|
||||
# only returns image / state features (the runtime drives
|
||||
# messages itself), so build a low-level prompt from current
|
||||
# runtime state and tokenize it inline.
|
||||
ctx = _control_context_messages(state)
|
||||
if state.get("current_subtask"):
|
||||
ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}]
|
||||
text_batch = _build_text_batch(self.policy, ctx)
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
observation = dict(observation)
|
||||
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||
try:
|
||||
action = self.policy.select_action(observation)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("select_action skipped: %s", exc)
|
||||
return None
|
||||
# SmolVLA returns a single action; if the underlying policy
|
||||
# streams chunks, split per-step here. For v1 we just enqueue
|
||||
# the result.
|
||||
|
||||
Reference in New Issue
Block a user