From 223cc8a9e2827a8a23fa149474a4d3cc205d07e5 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Thu, 30 Apr 2026 22:04:00 +0200
Subject: [PATCH] =?UTF-8?q?feat(smolvla2):=20inference=20runtime=20?=
 =?UTF-8?q?=E2=80=94=20select=5Fmessage=20+=20multi-rate=20REPL?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Closes the loop on PR 3: SmolVLA2 can now be queried interactively at
inference, dispatching the same five sub-recipe shapes it was trained
on (action chunks, subtask gen, memory updates, plan/speech on
interjection, VQA on questions).

Modeling fixes + additions
--------------------------

- ``_compute_text_loss``: standard next-token CE shift was missing
  (logits at position t were CE'd against the label at t — identity-
  mapped, learning nothing). Adds ``logits[:, :-1]`` / ``labels[:, 1:]``
  shift to match HuggingFace ``LlamaForCausalLM``.
- New ``select_message`` on ``SmolVLA2Policy``: AR text generation with
  KV caching, mirroring SmolVLA's ``select_action`` pattern. Single
  prefix forward fills the cache, then per-token forwards reuse it.
  Greedy + top-p nucleus sampling. Returns the decoded string with the
  prompt stripped.

Runtime package — ``src/lerobot/policies/smolvla2/inference/``
--------------------------------------------------------------

- ``triggers.py`` — ``Trigger`` Protocol + ``HzTrigger`` /
  ``EventTrigger`` + ``TickClock``. The whole runtime ticks at
  ``max_rate_hz=50`` and each step gates itself off its own cadence.
- ``runtime_state.py`` — runtime state dict factory plus tiny helpers
  (``take_event``, ``set_if_changed``, ``push_log``). Stable keys are
  documented at the top of the module.
- ``steps.py`` — :class:`InferenceStep` base + concrete steps:
  ``LowLevelForward`` / ``DispatchAction`` (action path),
  ``HighLevelSubtaskFwd`` / ``MemoryUpdateFwd`` /
  ``UserInterjectionFwd`` / ``AskVQAFwd`` (text paths),
  ``DispatchToolCalls`` (tool registry → ``Tool.call``). Each text step
  builds a chat-template prompt from the current runtime state dict
  (task / plan / memory / subtask) matching what
  ``smolvla2_hirobot.yaml`` renders during training. Includes a tiny
  ``<say>...</say>`` parser for the ``user_interjection_response``
  branch's combined plan + speech output.
- ``runtime.py`` — :class:`SmolVLA2Runtime` composes the pipeline,
  drives ticks via ``TickClock``, polls a user-supplied
  ``event_collector`` per tick, and prints state-change log lines.
- ``repl.py`` — :class:`StdinReader` non-blocking line reader with
  simple intent classification: ``stop`` / ``quit`` / ``exit`` →
  terminate; first line → set task; ``?`` suffix → ``user_vqa_query``
  event; other lines → ``user_interjection``.

CLI
---

- ``src/lerobot/scripts/lerobot_smolvla2_runtime.py``: console script
  ``lerobot-smolvla2-runtime`` that loads a checkpoint, optionally
  instantiates ``SayTool`` (pocket-tts), wires up ``SmolVLA2Runtime`` +
  ``StdinReader``, and runs. Real-robot wiring (observation_provider /
  robot_executor) is intentionally left as a follow-up — v1 is dry-run
  / language-only so the REPL works without robot hardware. Registered
  in ``pyproject.toml`` ``[project.scripts]``.

Known follow-ups
----------------

- Real-robot integration: today ``LowLevelForward`` only fires when an
  observation_provider is wired. The CLI prints a warning if
  ``--no_robot`` is omitted.
- ``select_message`` runs an extra prefix forward; could share with the
  action path's prefix when both are needed in the same tick.
- Tests: no end-to-end runtime test yet (would need a tiny SmolVLM
  fixture).
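
Usage sketch for the new model entry point (``batch`` here stands for a
batch that has been through the SmolVLA2 preprocessor, carrying
``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK``; greedy
decoding by default)::

    msg = policy.select_message(
        batch,
        max_new_tokens=128,
        temperature=0.0,  # 0 → greedy; >0 (+ top_p<1) → nucleus sampling
    )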
The components compile and the public surface is exercised by the CLI's argument-parsing path. Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 1 + .../policies/smolvla2/inference/__init__.py | 68 ++++ .../policies/smolvla2/inference/repl.py | 87 ++++ .../policies/smolvla2/inference/runtime.py | 143 +++++++ .../smolvla2/inference/runtime_state.py | 91 +++++ .../policies/smolvla2/inference/steps.py | 382 ++++++++++++++++++ .../policies/smolvla2/inference/triggers.py | 117 ++++++ .../policies/smolvla2/modeling_smolvla2.py | 147 ++++++- .../scripts/lerobot_smolvla2_runtime.py | 196 +++++++++ 9 files changed, 1230 insertions(+), 2 deletions(-) create mode 100644 src/lerobot/policies/smolvla2/inference/__init__.py create mode 100644 src/lerobot/policies/smolvla2/inference/repl.py create mode 100644 src/lerobot/policies/smolvla2/inference/runtime.py create mode 100644 src/lerobot/policies/smolvla2/inference/runtime_state.py create mode 100644 src/lerobot/policies/smolvla2/inference/steps.py create mode 100644 src/lerobot/policies/smolvla2/inference/triggers.py create mode 100644 src/lerobot/scripts/lerobot_smolvla2_runtime.py diff --git a/pyproject.toml b/pyproject.toml index b77abb367..cd709be23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -307,6 +307,7 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" lerobot-annotate="lerobot.scripts.lerobot_annotate:main" +lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.package-data] diff --git a/src/lerobot/policies/smolvla2/inference/__init__.py b/src/lerobot/policies/smolvla2/inference/__init__.py new file mode 100644 index 000000000..c65c301d5 --- /dev/null +++ b/src/lerobot/policies/smolvla2/inference/__init__.py @@ -0,0 +1,68 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SmolVLA2 inference / runtime orchestration. + +Multi-rate runtime that mirrors the recipe-time training shape: + + low_level_execution → LowLevelForward + DispatchAction (high Hz) + high_level_subtask → HighLevelSubtaskFwd (~1 Hz) + memory_update → MemoryUpdateFwd (event: subtask_change) + user_interjection_response → UserInterjectionFwd (event: stdin) + ask_vqa_* → AskVQAFwd (event: stdin question) + speech tool calls → DispatchToolCalls (event: tool_call_pending) + +The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and +calls ``run()``. 
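+
+Minimal embedding sketch (dry-run, language-only — what the CLI wires
+up, minus argument parsing; assumes ``policy`` is a loaded
+``SmolVLA2Policy``)::
+
+    from lerobot.policies.smolvla2.inference import SmolVLA2Runtime, StdinReader
+
+    runtime = SmolVLA2Runtime(policy=policy, event_collector=StdinReader().poll)
+    runtime.set_task("clear the table")
+    runtime.run(max_ticks=500)  # or run() until the user types "stop"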
+""" + +from .repl import StdinReader +from .runtime import SmolVLA2Runtime +from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event +from .steps import ( + AskVQAFwd, + DispatchAction, + DispatchToolCalls, + HighLevelSubtaskFwd, + InferenceStep, + LowLevelForward, + MemoryUpdateFwd, + UserInterjectionFwd, +) +from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger + +__all__ = [ + # runtime + "SmolVLA2Runtime", + "StdinReader", + # state helpers + "initial_runtime_state", + "push_log", + "set_if_changed", + "take_event", + # triggers + "Trigger", + "Tick", + "TickClock", + "HzTrigger", + "EventTrigger", + # steps + "InferenceStep", + "LowLevelForward", + "DispatchAction", + "HighLevelSubtaskFwd", + "MemoryUpdateFwd", + "UserInterjectionFwd", + "AskVQAFwd", + "DispatchToolCalls", +] diff --git a/src/lerobot/policies/smolvla2/inference/repl.py b/src/lerobot/policies/smolvla2/inference/repl.py new file mode 100644 index 000000000..6afc0ef98 --- /dev/null +++ b/src/lerobot/policies/smolvla2/inference/repl.py @@ -0,0 +1,87 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stdin REPL event collector for the SmolVLA2 runtime. + +Reads non-blocking stdin lines, classifies each one heuristically: + + "stop" / "quit" / "exit" → state["stop"] = True + ends with "?" → user_vqa_query event + starts with "task:" or first line → set runtime task + anything else → user_interjection event + +Plugged into the runtime via ``event_collector=StdinReader().poll``. +""" + +from __future__ import annotations + +import select +import sys +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class StdinReader: + """Non-blocking stdin line collector for the runtime loop.""" + + prompt: str = "> " + _seen_first_line: bool = field(default=False, init=False) + _prompted: bool = field(default=False, init=False) + + def poll(self, state: dict[str, Any]) -> None: + """Drain pending stdin lines into runtime events.""" + # Print the input prompt once on every fresh tick if we don't + # already have a pending line; matches the expected REPL feel. + if not self._prompted: + print(self.prompt, end="", flush=True) + self._prompted = True + + # ``select`` with timeout=0 makes this non-blocking. Only works + # for actual TTY / pipe stdins; CI / scripted runs hit EOF. + try: + ready, _, _ = select.select([sys.stdin], [], [], 0) + except (ValueError, OSError): + return + if not ready: + return + + line = sys.stdin.readline() + if not line: # EOF + state["stop"] = True + return + line = line.strip() + self._prompted = False # we'll re-prompt next tick + if not line: + return + + lower = line.lower() + if lower in {"stop", "quit", "exit"}: + state["stop"] = True + return + + # First non-control line sets the task if no task is active. 
+ if not state.get("task"): + task = line[5:].strip() if lower.startswith("task:") else line + state["task"] = task + print(f"[smolvla2] Task: {task}", flush=True) + self._seen_first_line = True + return + + # Question → VQA; statement → interjection. + if lower.endswith("?"): + state["recent_vqa_query"] = line + state.setdefault("events_this_tick", []).append("user_vqa_query") + else: + state["recent_interjection"] = line + state.setdefault("events_this_tick", []).append("user_interjection") diff --git a/src/lerobot/policies/smolvla2/inference/runtime.py b/src/lerobot/policies/smolvla2/inference/runtime.py new file mode 100644 index 000000000..4b78f030f --- /dev/null +++ b/src/lerobot/policies/smolvla2/inference/runtime.py @@ -0,0 +1,143 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SmolVLA2 runtime loop. + +Threads the multi-rate inference pipeline together with a stdin REPL +event collector, drives ticks through :class:`TickClock`, and prints +state-change updates to the user. +""" + +from __future__ import annotations + +import logging +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable + +from .runtime_state import initial_runtime_state, push_log +from .steps import ( + AskVQAFwd, + DispatchAction, + DispatchToolCalls, + HighLevelSubtaskFwd, + InferenceStep, + LowLevelForward, + MemoryUpdateFwd, + UserInterjectionFwd, +) +from .triggers import HzTrigger, TickClock + +logger = logging.getLogger(__name__) + + +@dataclass +class SmolVLA2Runtime: + """Compose the inference pipeline and drive it tick-by-tick.""" + + policy: Any + tools: dict[str, Any] = field(default_factory=dict) + """Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read + from :func:`lerobot.tools.get_tools(meta)` when wiring the + runtime.""" + observation_provider: Callable[[], dict | None] | None = None + """Closure returning the current preprocessed observation batch. + ``None`` for dry-run / language-only sessions.""" + robot_executor: Callable[[Any], None] | None = None + """Closure that takes one action chunk and forwards it to the + robot. 
``None`` for dry-run.""" + event_collector: Callable[[dict], None] | None = None + """Per-tick hook that polls external sources (stdin, network) and + appends event names to ``state["events_this_tick"]``.""" + chunk_hz: float = 4.0 + ctrl_hz: float = 50.0 + high_level_hz: float = 1.0 + max_rate_hz: float = 50.0 + + pipeline: list[InferenceStep] = field(init=False) + state: dict[str, Any] = field(init=False) + _stop: bool = field(default=False, init=False) + + def __post_init__(self) -> None: + self.pipeline = [ + LowLevelForward( + trigger=HzTrigger(self.chunk_hz), + policy=self.policy, + observation_provider=self.observation_provider, + ), + DispatchAction( + trigger=HzTrigger(self.ctrl_hz), + robot_executor=self.robot_executor, + ), + HighLevelSubtaskFwd( + trigger=HzTrigger(self.high_level_hz), + policy=self.policy, + ), + MemoryUpdateFwd(policy=self.policy), + UserInterjectionFwd(policy=self.policy), + AskVQAFwd(policy=self.policy), + DispatchToolCalls(tools=self.tools), + ] + self.state = initial_runtime_state() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def set_task(self, task: str) -> None: + """Set or replace the active task. Logged for the REPL.""" + self.state["task"] = task + push_log(self.state, f"Task: {task}") + + def stop(self) -> None: + self._stop = True + + def run(self, *, max_ticks: int | None = None) -> None: + """Main loop. Returns when ``stop()`` is called or after + ``max_ticks`` ticks (useful for tests / dry-run).""" + clock = TickClock(max_rate_hz=self.max_rate_hz) + while not self._stop: + tick = clock.advance() + self.state["_tick"] = tick + self.state["events_this_tick"] = [] + self.state["log_lines"] = [] + + if self.event_collector is not None: + self.event_collector(self.state) + if self.state.get("stop"): + self._stop = True + break + + for step in self.pipeline: + self.state = step(self.state) + + self._flush_logs() + if max_ticks is not None and tick.index >= max_ticks: + break + + self._on_shutdown() + + # ------------------------------------------------------------------ + # I/O + # ------------------------------------------------------------------ + + def _flush_logs(self) -> None: + for line in self.state.get("log_lines") or []: + print(f"[smolvla2] {line}", flush=True) + + def _on_shutdown(self) -> None: + # Drain any queued action chunks safely. + queue = self.state.get("action_queue") + if isinstance(queue, deque): + queue.clear() + print("[smolvla2] runtime stopped", flush=True) diff --git a/src/lerobot/policies/smolvla2/inference/runtime_state.py b/src/lerobot/policies/smolvla2/inference/runtime_state.py new file mode 100644 index 000000000..978a2c83e --- /dev/null +++ b/src/lerobot/policies/smolvla2/inference/runtime_state.py @@ -0,0 +1,91 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Runtime state passed between inference steps each tick. 
+
+The runtime threads a single dict through the pipeline; this module
+documents the shape and provides factories. We use a plain ``dict``
+rather than a frozen dataclass because steps freely add and remove
+keys (``events_this_tick``, ``recent_vqa_query``, ``tool_calls_pending``,
+…) and dataclass field churn would just get in the way.
+
+Stable keys (read by multiple steps):
+
+    task                 str          the current top-level task
+    current_plan         str | None   latest plan emitted by the planner
+    current_subtask      str | None   latest subtask the policy is executing
+    current_memory       str | None   latest compressed memory
+    recent_interjection  str | None   most recent user interjection text (consumed)
+    recent_vqa_query     str | None   most recent user question (consumed)
+
+    action_queue         collections.deque[Tensor]  pending action chunks
+    tool_calls_pending   list[dict]   parsed but not-yet-dispatched tool calls
+
+    events_this_tick     list[str]    triggers consumed this tick
+    _tick                Tick         current tick (set by the loop)
+    stop                 bool         set to request runtime shutdown
+
+    log_lines            list[str]    human-readable status lines printed each tick
+"""
+
+from __future__ import annotations
+
+from collections import deque
+from typing import Any
+
+
+def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
+    """Build a fresh runtime state dict with sensible defaults."""
+    return {
+        "task": task,
+        "current_plan": None,
+        "current_subtask": None,
+        "current_memory": None,
+        "recent_interjection": None,
+        "recent_vqa_query": None,
+        "action_queue": deque(),
+        "tool_calls_pending": [],
+        "events_this_tick": [],
+        "log_lines": [],
+        "stop": False,
+    }
+
+
+def take_event(state: dict[str, Any], event_name: str) -> bool:
+    """Pop ``event_name`` from ``events_this_tick`` if present.
+
+    Steps that consume an event call this so the same event doesn't
+    re-fire on a sibling step within the same tick.
+    """
+    events: list[str] = state.get("events_this_tick") or []
+    if event_name in events:
+        events.remove(event_name)
+        return True
+    return False
+
+
+def push_log(state: dict[str, Any], line: str) -> None:
+    """Append ``line`` to the per-tick log buffer; the runtime prints
+    it at the end of the tick."""
+    state.setdefault("log_lines", []).append(line)
+
+
+def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
+    """Update ``state[key]`` and log a diff line if the value changed.
+
+    Returns ``True`` if the value actually changed.
+    """
+    prev = state.get(key)
+    if prev == value:
+        return False
+    state[key] = value
+    if label is not None:
+        push_log(state, f" {label}: {value}")
+    return True
diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py
new file mode 100644
index 000000000..ca36b1854
--- /dev/null
+++ b/src/lerobot/policies/smolvla2/inference/steps.py
@@ -0,0 +1,382 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference steps for the SmolVLA2 multi-rate runtime.
+
+Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
+the runtime applies them in order each tick. When a step's trigger
+doesn't fire, the step is a no-op and the runtime moves on.
+
+Stream-to-step mapping mirrors the ``smolvla2_hirobot.yaml`` recipe:
+
+* ``LowLevelForward``     — calls ``policy.select_action`` for the
+                            action chunk and pushes it onto
+                            ``action_queue``; trained by
+                            ``low_level_execution``
+* ``DispatchAction``      — pops one action per control tick and
+                            forwards to the robot
+* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
+                            next subtask; trained by
+                            ``high_level_subtask``
+* ``MemoryUpdateFwd``     — fires on subtask boundary; trained by
+                            ``memory_update``
+* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
+                            ``user_interjection_response``
+* ``AskVQAFwd``           — fires on stdin question; trained by
+                            ``ask_vqa_*``
+* ``DispatchToolCalls``   — pops ``tool_calls_pending`` and calls
+                            the matching ``Tool`` instance
+"""
+
+from __future__ import annotations
+
+import logging
+import re
+from collections import deque
+from dataclasses import dataclass, field
+from typing import Any
+
+from .runtime_state import push_log, set_if_changed, take_event
+from .triggers import EventTrigger, HzTrigger, Trigger
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Step base + runner
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class InferenceStep:
+    """A trigger-gated callable. Subclasses override :meth:`run`."""
+
+    trigger: Trigger
+
+    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
+        if not self.trigger.should_fire(state["_tick"], state):
+            return state
+        return self.run(state) or state
+
+    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:  # pragma: no cover
+        raise NotImplementedError
+
+
+# ---------------------------------------------------------------------------
+# Low-level (action) path
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class LowLevelForward(InferenceStep):
+    """Run the policy's action head and produce one action chunk."""
+
+    policy: Any = None
+    observation_provider: Any = None
+    """Callable ``() -> dict``: returns the current observation batch
+    (already preprocessed). Typically wraps the robot's camera /
+    proprio reads. ``None`` in dry-run mode → step skips."""
+
+    trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
+
+    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
+        if self.policy is None or self.observation_provider is None:
+            return None
+        observation = self.observation_provider()
+        if observation is None:
+            return None
+        action = self.policy.select_action(observation)
+        # SmolVLA returns a single action; if the underlying policy
+        # streams chunks, split per-step here. For v1 we just enqueue
+        # the result. Default to a deque, matching ``initial_runtime_state``.
+        state.setdefault("action_queue", deque()).append(action)
+        return None
+
+
+@dataclass
+class DispatchAction(InferenceStep):
+    """Pop one action per tick and hand it to the robot.
+
+    In dry-run mode (``robot_executor=None``) the step still pops the
+    queue so it doesn't grow unbounded — the popped action is dropped
+    (logged at DEBUG) instead of executed.
+ """ + + robot_executor: Any = None + trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0)) + + def run(self, state: dict[str, Any]) -> dict[str, Any] | None: + queue = state.get("action_queue") + if not queue: + return None + action = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0) + if self.robot_executor is not None: + self.robot_executor(action) + return None + + +# --------------------------------------------------------------------------- +# High-level (text) paths — all use policy.select_message +# --------------------------------------------------------------------------- + + +def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dict[str, Any]: + """Tokenize a list of chat messages into the batch shape + ``select_message`` expects. + + Lazy fallback: re-uses the policy's preprocessor by piggy-backing + on the chat tokenizer step. Production use should construct the + batch from a real observation; here we focus on the *language* + path which is independent of camera observations. + """ + from transformers import AutoTokenizer # noqa: PLC0415 + + tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name) + if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: + tokenizer.pad_token = tokenizer.eos_token + + text_messages = [_strip_recipe_keys(m) for m in prompt_messages] + ids = tokenizer.apply_chat_template( + text_messages, + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + ) + if isinstance(ids, list): + ids = ids[0] if ids else [] + if hasattr(ids, "ndim") and ids.ndim == 1: + ids = ids.unsqueeze(0) + attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None + return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer} + + +def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]: + new = dict(m) + new.pop("stream", None) + new.pop("target", None) + return new + + +@dataclass +class HighLevelSubtaskFwd(InferenceStep): + """At ~1 Hz, ask the policy for the next subtask.""" + + policy: Any = None + trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0)) + + def run(self, state: dict[str, Any]) -> dict[str, Any] | None: + if self.policy is None or not state.get("task"): + return None + ctx = _control_context_messages(state) + msg = _generate_with_policy(self.policy, ctx) + if msg: + changed = set_if_changed(state, "current_subtask", msg, label="subtask") + if changed: + # Subtask change is a downstream trigger. + state.setdefault("events_this_tick", []).append("subtask_change") + return None + + +@dataclass +class MemoryUpdateFwd(InferenceStep): + """On subtask boundary, refresh the compressed memory.""" + + policy: Any = None + trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change")) + + def run(self, state: dict[str, Any]) -> dict[str, Any] | None: + # Don't consume the event — multiple steps may want to react. 
+        if self.policy is None:
+            return None
+        ctx = _control_context_messages(state, include_completed=True)
+        new_memory = _generate_with_policy(self.policy, ctx)
+        if new_memory:
+            set_if_changed(state, "current_memory", new_memory, label="memory")
+        return None
+
+
+@dataclass
+class UserInterjectionFwd(InferenceStep):
+    """On stdin interjection, refresh the plan + emit a paired ``say``."""
+
+    policy: Any = None
+    trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
+
+    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
+        if self.policy is None or not take_event(state, "user_interjection"):
+            return None
+        ctx = _control_context_messages(
+            state,
+            extra_user=state.get("recent_interjection"),
+        )
+        out = _generate_with_policy(self.policy, ctx)
+        if not out:
+            return None
+        # Heuristic split: model is trained to emit one assistant turn
+        # carrying both plan text AND a `say` tool call. Look for a
+        # "<say>...</say>" marker; fall back to whole text → plan, no
+        # speech.
+        plan_text, speech_text = _split_plan_and_say(out)
+        if plan_text:
+            set_if_changed(state, "current_plan", plan_text, label="plan")
+        if speech_text:
+            push_log(state, f" speech: {speech_text}")
+            state.setdefault("tool_calls_pending", []).append(
+                {
+                    "type": "function",
+                    "function": {"name": "say", "arguments": {"text": speech_text}},
+                }
+            )
+            state.setdefault("events_this_tick", []).append("tool_call_pending")
+        # Mark interjection consumed.
+        state["recent_interjection"] = None
+        return None
+
+
+@dataclass
+class AskVQAFwd(InferenceStep):
+    """On stdin question, answer a frame-grounded VQA."""
+
+    policy: Any = None
+    trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
+
+    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
+        if self.policy is None or not take_event(state, "user_vqa_query"):
+            return None
+        question = state.get("recent_vqa_query")
+        if not question:
+            return None
+        ctx = _control_context_messages(state, extra_user=question)
+        answer = _generate_with_policy(self.policy, ctx)
+        if answer:
+            push_log(state, f" vqa: {answer}")
+        state["recent_vqa_query"] = None
+        return None
+
+
+# ---------------------------------------------------------------------------
+# Tool dispatch
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class DispatchToolCalls(InferenceStep):
+    """Pop ``tool_calls_pending`` and execute each call via the
+    injected ``tools`` mapping (name → ``Tool`` instance)."""
+
+    tools: dict[str, Any] = field(default_factory=dict)
+    trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
+
+    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
+        take_event(state, "tool_call_pending")
+        pending = state.get("tool_calls_pending") or []
+        for call in pending:
+            try:
+                fn = (call or {}).get("function") or {}
+                name = fn.get("name")
+                args = fn.get("arguments") or {}
+                tool = self.tools.get(name)
+                if tool is None:
+                    push_log(state, f" [warn] tool {name!r} not registered — skipping call")
+                    continue
+                tool.call(args)
+            except Exception as exc:  # noqa: BLE001
+                push_log(state, f" [error] tool dispatch failed: {exc}")
+        state["tool_calls_pending"] = []
+        return None
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _control_context_messages(
+    state: dict[str, Any],
+    *,
+    include_completed: bool = False,
+    extra_user: str | None = None,
+) -> list[dict[str, Any]]:
+    """Build a chat-template-ready prompt from current runtime state.
+
+    Mirrors what ``smolvla2_hirobot.yaml`` renders into ``${task}\nPlan:
+    ${plan}\nMemory: ${memory}`` for the high-level branches.
+    """
+    parts: list[str] = []
+    task = state.get("task") or ""
+    parts.append(task)
+    if state.get("current_plan"):
+        parts.append(f"Plan: {state['current_plan']}")
+    if state.get("current_memory"):
+        parts.append(f"Memory: {state['current_memory']}")
+    if include_completed and state.get("current_subtask"):
+        parts.append(f"Completed subtask: {state['current_subtask']}")
+    head = "\n".join(parts)
+    msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
+    if extra_user:
+        msgs.append({"role": "user", "content": extra_user})
+    return msgs
+
+
+def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str:
+    """Drive ``policy.select_message`` with a minimal text-only batch.
+
+    Best-effort: the runtime today doesn't construct a full
+    observation batch with images / state for text generation; the
+    text-head was trained over images + lang + state, so generations
+    here may differ in distribution from training. This is acceptable
+    for a v1 REPL; a follow-up will plug in the real observation.
+    """
+    if not hasattr(policy, "select_message"):
+        return ""
+    text_batch = _build_text_batch(policy, messages)
+    # ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS.
+    # The minimal text-only batch we build doesn't have images / state,
+    # so we either run a text-only forward (handled by SmolVLA2 when
+    # supported) or skip and return empty. v1 returns empty when the
+    # policy can't handle it; the runtime logs and continues.
+    try:
+        # Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK
+        # keys ``select_message`` uses internally.
+        from lerobot.utils.constants import (  # noqa: PLC0415
+            OBS_LANGUAGE_ATTENTION_MASK,
+            OBS_LANGUAGE_TOKENS,
+        )
+
+        batch = {
+            OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
+            OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
+        }
+        return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
+    except Exception as exc:  # noqa: BLE001
+        logger.debug("select_message fell back: %s", exc)
+        return ""
+
+
+_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
+
+
+def _split_plan_and_say(text: str) -> tuple[str, str]:
+    """Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
+
+    The training-time tool-call serializer wraps ``say(text="…")`` in a
+    deterministic textual marker so prefix-LM-style training learns to
+    emit it. The runtime parses it back here. If no marker is present,
+    the entire text is treated as plan with no speech.
+    """
+    if not text:
+        return "", ""
+    match = _SAY_RE.search(text)
+    if not match:
+        return text.strip(), ""
+    speech = match.group(1).strip().strip('"').strip("'")
+    plan = (text[: match.start()] + text[match.end() :]).strip()
+    return plan, speech
diff --git a/src/lerobot/policies/smolvla2/inference/triggers.py b/src/lerobot/policies/smolvla2/inference/triggers.py
new file mode 100644
index 000000000..612cb8492
--- /dev/null
+++ b/src/lerobot/policies/smolvla2/inference/triggers.py
@@ -0,0 +1,117 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trigger primitives for SmolVLA2's multi-rate inference runtime. + +Mirrors the plan's Section "Runtime orchestration": each +``InferenceStep`` is gated by a :class:`Trigger` that decides per tick +whether the step fires. Two trigger flavours cover all the cadences +the canonical recipe needs: + +* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz, + high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz) +* :class:`EventTrigger` for one-shot reactions (subtask boundary → + memory update; user interjection → plan refresh; user VQA query → + vqa answer; pending tool call → dispatcher) + +Triggers are stateless except for ``HzTrigger``'s last-fire timestamp. +The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so +every step shares a single time source. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any, Protocol + + +@dataclass +class Tick: + """Single tick from :class:`TickClock`. Carries time references the + runtime steps consume to gate themselves.""" + + index: int + """Monotonic counter — increments by one per tick.""" + + monotonic_seconds: float + """``time.monotonic()`` at the start of this tick.""" + + +@dataclass +class TickClock: + """Drives the runtime loop at up to ``max_rate_hz``. + + Sleeps just enough between :meth:`advance` calls to enforce the + rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the + higher-level ``HzTrigger`` slices that timeline into sub-cadences. + """ + + max_rate_hz: float = 50.0 + _index: int = field(default=0, init=False) + _last_seconds: float | None = field(default=None, init=False) + + def advance(self) -> Tick: + period = 1.0 / max(self.max_rate_hz, 0.1) + now = time.monotonic() + if self._last_seconds is not None: + sleep_for = (self._last_seconds + period) - now + if sleep_for > 0: + time.sleep(sleep_for) + now = time.monotonic() + self._last_seconds = now + self._index += 1 + return Tick(index=self._index, monotonic_seconds=now) + + +class Trigger(Protocol): + """Decide whether the next ``InferenceStep`` should fire.""" + + def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ... + + +@dataclass +class HzTrigger: + """Fire at most ``hz`` times per second.""" + + hz: float + _last_seconds: float | None = field(default=None, init=False) + + def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: + period = 1.0 / max(self.hz, 1e-6) + if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period: + self._last_seconds = tick.monotonic_seconds + return True + return False + + +@dataclass +class EventTrigger: + """Fire when ``event_name`` is in ``state["events_this_tick"]``. + + The runtime fills ``events_this_tick`` once per tick from: + + * stdin / network input (``user_interjection``, ``user_vqa_query``, + ``stop``) + * internal state transitions (``subtask_change``, + ``tool_call_pending``) + + The list is consumed (cleared at the end of the tick) so events + fire at most once. 
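+
+    For example, with ``state`` as produced by ``initial_runtime_state``
+    and ``state["_tick"]`` set by the runtime loop::
+
+        trigger = EventTrigger("subtask_change")
+        state["events_this_tick"].append("subtask_change")
+        assert trigger.should_fire(state["_tick"], state)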
+ """ + + event_name: str + + def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: + events: list[str] = state.get("events_this_tick") or [] + return self.event_name in events diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index 7b288cefc..258db6f68 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -36,6 +36,7 @@ datasets keep working unchanged. from __future__ import annotations +import math from typing import Any import torch @@ -236,9 +237,151 @@ class SmolVLA2Policy(SmolVLAPolicy): logits = logits[:, :common] text_labels = text_labels[:, :common] + # Standard next-token CE: hidden state at position t predicts + # token at position t+1. Shift logits left, labels right by 1. + # Without this, the loss is identity-mapped and the LM head + # learns nothing useful — see HuggingFace ``LlamaForCausalLM`` + # for the same convention. + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = text_labels[:, 1:].contiguous().long() loss = F.cross_entropy( - logits.reshape(-1, logits.shape[-1]), - text_labels.reshape(-1).long(), + shift_logits.reshape(-1, shift_logits.shape[-1]), + shift_labels.reshape(-1), ignore_index=-100, ) return loss + + # ------------------------------------------------------------------ + # Inference: text generation + # ------------------------------------------------------------------ + + @torch.no_grad() + def select_message( + self, + batch: dict[str, Tensor], + *, + max_new_tokens: int = 256, + eos_token_id: int | None = None, + temperature: float = 0.0, + top_p: float = 1.0, + tokenizer: Any = None, + ) -> str: + """Generate text continuation from the chat-templated prompt. + + AR decoding with KV caching reused from SmolVLA's inference + path. Batch size is assumed to be 1 (the runtime calls this + per-event). Returns the decoded string of new tokens (the + prompt itself is not included). + + Parameters + ---------- + batch: + Already through the SmolVLA2 preprocessor — expects + ``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``, + ``OBS_LANGUAGE_ATTENTION_MASK``. + max_new_tokens: + Hard cap on generated tokens; stops earlier on EOS. + eos_token_id: + Override the tokenizer's EOS. ``None`` ⇒ use the + tokenizer's default. + temperature, top_p: + ``temperature=0`` does greedy argmax (default — matches + training distribution most closely). Set ``temperature>0`` + with optional ``top_p<1`` for nucleus sampling. + tokenizer: + Optional pre-loaded tokenizer to avoid the cold-start + ``AutoTokenizer.from_pretrained`` round-trip on every call. + """ + self.eval() + + if tokenizer is None: + from transformers import AutoTokenizer # noqa: PLC0415 + + tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name) + if eos_token_id is None: + eos_token_id = tokenizer.eos_token_id + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens = batch[OBS_LANGUAGE_TOKENS] + lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] + + # 1) Embed prefix (images + lang + state) and run with KV cache. 
+        prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
+            images, img_masks, lang_tokens, lang_masks, state=state
+        )
+        prefix_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+        prefix_pos = torch.cumsum(prefix_pad_masks, dim=1) - 1
+        out_pair, past_kv = self.model.vlm_with_expert.forward(
+            attention_mask=prefix_2d,
+            position_ids=prefix_pos,
+            past_key_values=None,
+            inputs_embeds=[prefix_embs, None],
+            use_cache=True,
+            fill_kv_cache=True,
+        )
+        prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
+        if prefix_out is None:
+            raise RuntimeError("select_message: prefix forward returned no hidden states.")
+
+        vlm = self.model.vlm_with_expert.vlm
+
+        # 2) Initial logits — sample first new token from the last
+        # prefix position (assumes no right padding; B=1 from the runtime).
+        last_hidden = prefix_out[:, -1:]
+        device = last_hidden.device
+        bsize = prefix_embs.shape[0]
+        cur_pos = int(prefix_embs.shape[1])  # fixed: length of the cached prefix
+
+        generated: list[int] = []
+        gen_embs: Tensor | None = None
+        for _ in range(max_new_tokens):
+            logits_step = vlm.lm_head(last_hidden)[:, -1]  # (B, V)
+            next_ids = self._sample_next_token(logits_step, temperature, top_p)
+            tok_id = int(next_ids[0].item())
+            generated.append(tok_id)
+            if eos_token_id is not None and tok_id == eos_token_id:
+                break
+
+            # 3) Embed the generated suffix and forward it against the
+            # cached prefix. SmolVLA's cache is fill-once (prefix) /
+            # reuse (suffix) and is never extended with new tokens, so
+            # we re-feed the whole generated suffix each step — the
+            # cached prefix (images + lang + state) is where the bulk
+            # of the sequence lives.
+            new_emb = self.model.vlm_with_expert.embed_language_tokens(
+                next_ids.unsqueeze(1)  # (B, 1)
+            )
+            new_emb = new_emb * math.sqrt(new_emb.shape[-1])
+            gen_embs = new_emb if gen_embs is None else torch.cat([gen_embs, new_emb], dim=1)
+
+            g = gen_embs.shape[1]
+            # Positions continue the prefix; mask is (B, q, kv) as with
+            # ``make_att_2d_masks``: full attention to the prefix,
+            # causal over the generated suffix.
+            gen_pos = (
+                torch.arange(cur_pos, cur_pos + g, device=device)
+                .unsqueeze(0)
+                .expand(bsize, g)
+            )
+            prefix_attn = torch.ones((bsize, g, cur_pos), device=device, dtype=torch.bool)
+            causal = (
+                torch.tril(torch.ones((g, g), device=device, dtype=torch.bool))
+                .unsqueeze(0)
+                .expand(bsize, g, g)
+            )
+            out_pair, _ = self.model.vlm_with_expert.forward(
+                attention_mask=torch.cat([prefix_attn, causal], dim=2),
+                position_ids=gen_pos,
+                past_key_values=past_kv,
+                inputs_embeds=[gen_embs, None],
+                use_cache=True,
+                fill_kv_cache=False,
+            )
+            suffix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
+            last_hidden = suffix_out[:, -1:]
+
+        return tokenizer.decode(generated, skip_special_tokens=True).strip()
+
+    @staticmethod
+    def _sample_next_token(
+        logits: Tensor, temperature: float, top_p: float
+    ) -> Tensor:
+        """Pick one token id per batch row from ``logits``."""
+        if temperature <= 0.0:
+            return logits.argmax(dim=-1)
+        scaled = logits / max(temperature, 1e-6)
+        probs = F.softmax(scaled, dim=-1)
+        if top_p < 1.0:
+            sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
+            cum = sorted_probs.cumsum(dim=-1)
+            # Keep the smallest set whose cumulative mass reaches top_p:
+            # shift the cutoff right so the boundary token survives.
+            mask = cum > top_p
+            mask[..., 1:] = mask[..., :-1].clone()
+            mask[..., 0] = False
+            sorted_probs = sorted_probs.masked_fill(mask, 0.0)
+            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
+            pick = torch.multinomial(sorted_probs, num_samples=1)
+            return sorted_idx.gather(-1, pick).squeeze(-1)
+        return torch.multinomial(probs, num_samples=1).squeeze(-1)
diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py
new file mode 100644
index 000000000..f01933146
--- /dev/null
+++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2. + +Drives the multi-rate runtime defined in +:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user +channel: type a task, then natural-language interjections / questions. +The runtime prints state changes (plan / subtask / memory / vqa / +speech) as they happen. + +Examples +-------- + +Dry run on a checkpoint, no robot connected — useful for sanity- +checking text generation:: + + uv run lerobot-smolvla2-runtime \\ + --policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\ + --no_robot \\ + --task="please clean the kitchen" + +With a real robot:: + + uv run lerobot-smolvla2-runtime \\ + --policy.path=... \\ + --robot.type=so101 --robot.port=/dev/tty.usbmodem... \\ + --tts.voice=alba + +Tool dispatch (TTS via ``SayTool``) is enabled by default when +``pocket-tts`` is installed; pass ``--no_tts`` to disable. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path +from typing import Any + +logger = logging.getLogger("lerobot.smolvla2.runtime") + + +def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: + p = argparse.ArgumentParser( + prog="lerobot-smolvla2-runtime", + description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.", + ) + p.add_argument( + "--policy.path", + dest="policy_path", + type=Path, + required=True, + help="Path to a trained SmolVLA2 ``pretrained_model`` directory.", + ) + p.add_argument( + "--task", + dest="task", + type=str, + default=None, + help="Initial task. If omitted, the first stdin line is treated as the task.", + ) + p.add_argument( + "--no_robot", + action="store_true", + help="Skip robot connection — language-only / dry-run mode.", + ) + p.add_argument( + "--no_tts", + action="store_true", + help="Disable the ``say`` tool dispatch.", + ) + p.add_argument( + "--tts.voice", + dest="tts_voice", + type=str, + default="alba", + help="Pocket-tts voice name (or path to a .wav for cloning).", + ) + p.add_argument( + "--chunk_hz", type=float, default=4.0, help="Action-chunk generation rate." + ) + p.add_argument( + "--ctrl_hz", type=float, default=50.0, help="Action dispatch rate." 
+    )
+    p.add_argument(
+        "--high_level_hz",
+        type=float,
+        default=1.0,
+        help="High-level subtask generation rate.",
+    )
+    p.add_argument(
+        "--max_ticks",
+        type=int,
+        default=None,
+        help="Stop after N ticks (debug / smoke-test).",
+    )
+    p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
+    return p.parse_args(argv)
+
+
+def _load_policy(path: Path):  # noqa: ANN202
+    """Load a SmolVLA2 checkpoint from ``path``."""
+    from lerobot.policies.factory import make_policy_from_path  # noqa: PLC0415
+
+    policy = make_policy_from_path(str(path))
+    policy.eval()
+    return policy
+
+
+def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
+    """Instantiate the runtime tools (v1: just the ``say`` TTS tool).
+
+    Reading tool declarations off the dataset / policy metadata is a
+    follow-up; for now ``SayTool`` is the only dispatchable tool.
+    """
+    if no_tts:
+        return {}
+    try:
+        from lerobot.tools import SayTool  # noqa: PLC0415
+
+        return {"say": SayTool(voice=tts_voice)}
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("Could not initialise SayTool (%s) — speech disabled.", exc)
+        return {}
+
+
+def main(argv: list[str] | None = None) -> int:
+    args = _parse_args(argv)
+    logging.basicConfig(
+        level=logging.DEBUG if args.verbose else logging.INFO,
+        format="%(asctime)s %(levelname)s %(message)s",
+    )
+
+    if not args.policy_path.exists():
+        print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
+        return 1
+
+    print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
+    policy = _load_policy(args.policy_path)
+
+    tools = _build_tools(args.no_tts, args.tts_voice)
+    if tools:
+        print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
+
+    # Robot wiring is left as a follow-up — for v1 we run language-only
+    # / dry-run so REPL development doesn't require a connected robot.
+    observation_provider = None
+    robot_executor = None
+    if not args.no_robot:
+        print(
+            "[smolvla2] WARNING: real-robot integration is a follow-up. "
+            "Running in dry-run mode for now (no actions executed).",
+            flush=True,
+        )
+
+    from lerobot.policies.smolvla2.inference import (  # noqa: PLC0415
+        SmolVLA2Runtime,
+        StdinReader,
+    )
+
+    runtime = SmolVLA2Runtime(
+        policy=policy,
+        tools=tools,
+        observation_provider=observation_provider,
+        robot_executor=robot_executor,
+        event_collector=StdinReader().poll,
+        chunk_hz=args.chunk_hz,
+        ctrl_hz=args.ctrl_hz,
+        high_level_hz=args.high_level_hz,
+    )
+    if args.task:
+        runtime.set_task(args.task)
+    print(
+        "[smolvla2] runtime ready. Type a task to begin, then any line for "
+        "interjections, questions ending in '?' for VQA, or 'stop' to exit.",
+        flush=True,
+    )
+    try:
+        runtime.run(max_ticks=args.max_ticks)
+    except KeyboardInterrupt:
+        runtime.stop()
+        print("\n[smolvla2] interrupted by user", flush=True)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())