mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
fix(smolvla2): tokenize lang prompt inline before select_action
LowLevelForward was handing the observation provider's output straight to ``policy.select_action``, but SmolVLA's ``_get_action_chunk`` indexes ``batch[OBS_LANGUAGE_TOKENS]`` and crashes with ``KeyError: 'observation.language.tokens'`` when the key isn't there. Our provider deliberately strips the dataset's language columns (the runtime drives messages itself), so nothing else was producing those tokens — the chunk path crashed on the very first tick after the task was set. Build a low-level prompt from the current runtime state inline (task / plan / memory as the user turn, with the current subtask appended as a continuation assistant turn when known), tokenize it with the same helper the high-level steps use, and merge the resulting ``lang_tokens`` / ``lang_masks`` into the observation under ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` before the call. Skip the step when no task is set yet, and swallow ``select_action`` exceptions at debug level so a missing observation feature doesn't kill the REPL. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -91,10 +91,35 @@ class LowLevelForward(InferenceStep):
|
|||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
if self.policy is None or self.observation_provider is None:
|
if self.policy is None or self.observation_provider is None:
|
||||||
return None
|
return None
|
||||||
|
if not state.get("task"):
|
||||||
|
# No task yet → nothing useful to condition on.
|
||||||
|
return None
|
||||||
observation = self.observation_provider()
|
observation = self.observation_provider()
|
||||||
if observation is None:
|
if observation is None:
|
||||||
return None
|
return None
|
||||||
|
# SmolVLA's ``select_action`` expects the full preprocessed
|
||||||
|
# batch, including ``OBS_LANGUAGE_TOKENS`` /
|
||||||
|
# ``OBS_LANGUAGE_ATTENTION_MASK``. The observation provider
|
||||||
|
# only returns image / state features (the runtime drives
|
||||||
|
# messages itself), so build a low-level prompt from current
|
||||||
|
# runtime state and tokenize it inline.
|
||||||
|
ctx = _control_context_messages(state)
|
||||||
|
if state.get("current_subtask"):
|
||||||
|
ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}]
|
||||||
|
text_batch = _build_text_batch(self.policy, ctx)
|
||||||
|
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
)
|
||||||
|
|
||||||
|
observation = dict(observation)
|
||||||
|
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||||
|
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||||
|
try:
|
||||||
action = self.policy.select_action(observation)
|
action = self.policy.select_action(observation)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.debug("select_action skipped: %s", exc)
|
||||||
|
return None
|
||||||
# SmolVLA returns a single action; if the underlying policy
|
# SmolVLA returns a single action; if the underlying policy
|
||||||
# streams chunks, split per-step here. For v1 we just enqueue
|
# streams chunks, split per-step here. For v1 we just enqueue
|
||||||
# the result.
|
# the result.
|
||||||
|
|||||||
Reference in New Issue
Block a user