Compare commits


2 Commits

Author SHA1 Message Date
Pepijn 223cc8a9e2 feat(smolvla2): inference runtime — select_message + multi-rate REPL
Closes the loop on PR 3: SmolVLA2 can now be queried interactively at
inference, dispatching the same five sub-recipe shapes it was trained
on (action chunks, subtask gen, memory updates, plan/speech on
interjection, VQA on questions).

Modeling fixes + additions
--------------------------

- ``_compute_text_loss``: standard next-token CE shift was missing
  (logits at position t were CE'd against the label at t — identity-
  mapped, learning nothing). Adds ``logits[:, :-1]`` /
  ``labels[:, 1:]`` shift to match HuggingFace ``LlamaForCausalLM``.

- New ``select_message`` on ``SmolVLA2Policy``: AR text generation
  with KV caching, mirroring SmolVLA's ``select_action`` pattern.
  Single prefix forward fills the cache, then per-token forwards
  reuse it. Greedy + top-p nucleus sampling. Returns the decoded
  string with the prompt stripped.
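The CE shift above is the standard causal-LM alignment: logits at position t must be scored against the label at t+1. A minimal sketch of that alignment (the helper name `next_token_ce` is illustrative, not the PR's actual function):

```python
import torch
import torch.nn.functional as F


def next_token_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Causal-LM cross-entropy: logits at position t predict the label at t+1.

    logits: (B, T, V); labels: (B, T) with -100 marking ignored positions.
    """
    shift_logits = logits[:, :-1, :].contiguous()  # drop the last position
    shift_labels = labels[:, 1:].contiguous()      # drop the first label
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```

Without the shift, position t is scored against its own label, which the model can solve by copying its input (the "identity-mapped" failure the commit describes).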

Runtime package — ``src/lerobot/policies/smolvla2/inference/``
-------------------------------------------------------------

- ``triggers.py`` — ``Trigger`` Protocol + ``HzTrigger`` /
  ``EventTrigger`` + ``TickClock``. The whole runtime ticks at
  ``max_rate_hz=50`` and each step gates itself off its own
  cadence.

- ``runtime_state.py`` — runtime state dict factory plus tiny
  helpers (``take_event``, ``set_if_changed``, ``push_log``).
  Stable keys are documented at the top of the module.

- ``steps.py`` — :class:`InferenceStep` base + concrete steps:
  ``LowLevelForward`` / ``DispatchAction`` (action path),
  ``HighLevelSubtaskFwd`` / ``MemoryUpdateFwd`` /
  ``UserInterjectionFwd`` / ``AskVQAFwd`` (text paths),
  ``DispatchToolCalls`` (tool registry → ``Tool.call``). Each
  text step builds a chat-template prompt from current
  ``RuntimeState`` (task / plan / memory / subtask) matching
  what ``smolvla2_hirobot.yaml`` renders during training.
  Includes a tiny ``<say>...</say>`` parser for the
  ``user_interjection_response`` branch's combined plan + speech
  output.

- ``runtime.py`` — :class:`SmolVLA2Runtime` composes the pipeline,
  drives ticks via ``TickClock``, polls a user-supplied
  ``event_collector`` per tick, and prints state-change log lines.

- ``repl.py`` — :class:`StdinReader` non-blocking line reader
  with simple intent classification: ``stop`` / ``quit`` /
  ``exit`` → terminate; ``?`` suffix → ``user_vqa_query`` event;
  first line → set task; other lines → ``user_interjection``.

CLI
---

- ``src/lerobot/scripts/lerobot_smolvla2_runtime.py``: console
  script ``lerobot-smolvla2-runtime`` that loads a checkpoint,
  optionally instantiates ``SayTool`` (pocket-tts), wires up
  ``SmolVLA2Runtime`` + ``StdinReader``, and runs.

  Real-robot wiring (observation_provider / robot_executor) is
  intentionally left as a follow-up — v1 is dry-run / language-
  only so the REPL works without robot hardware.

  Registered in ``pyproject.toml`` ``[project.scripts]``.

Known follow-ups
----------------

- Real-robot integration: today ``LowLevelForward`` only fires when
  an observation_provider is wired. The CLI prints a warning if
  ``--no_robot`` is omitted.
- ``select_message`` runs an extra prefix forward; could share with
  the action path's prefix when both are needed in the same tick.
- Tests: no end-to-end runtime test yet (would need a tiny SmolVLM
  fixture). The components compile and the public surface is
  exercised by the CLI's argument-parsing path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 22:04:00 +02:00
Pepijn af6d8ebd5b feat(smolvla2): dual-head forward — flow loss + lm_head text loss
The third and final commit of PR 3's SmolVLA2 work. Wires the actual
training signal through:

* ``predict_actions[i] = True``  → sample i contributes to flow loss
* ``text_labels[i, t] != -100``  → token t of sample i contributes to
                                    LM-head cross-entropy

Both routing knobs come from ``SmolVLA2ChatTokenizerStep`` (previous
commit on this branch), which builds them from the recipe's
``message_streams`` / ``target_message_indices``. The per-sample
``predict_actions`` mask preserves the Pi0.5 convention from the
plan's Section I.7: "True iff any low_level target exists".

Implementation:

- ``forward`` reads ``text_labels`` and ``predict_actions`` from the
  batch. When neither is present (vanilla SmolVLA usage with no
  recipe), delegates to ``SmolVLAPolicy.forward`` so unannotated
  datasets keep training as before — full backward compatibility.
- ``flow_loss``: super().forward(reduction="none") returns the
  per-sample (B,) flow loss; we mask non-action samples with the
  ``predict_actions`` bool and renormalize by the count of action
  samples. ``flow_loss_weight = 0`` in the config disables this
  branch entirely (text-only training).
- ``text_loss``: a prefix-only forward through the VLM (no action
  expert / suffix), slicing the lang-token range out of the
  resulting hidden states (``embed_prefix`` orders the prefix as
  ``[image_blocks..., lang, state]`` so the slice is unambiguous).
  Apply ``vlm.lm_head`` to those hidden states, cross-entropy with
  ``text_labels`` (ignore_index=-100). ``text_loss_weight = 0``
  disables this branch (reverts to flow-only behaviour, matching
  SmolVLA exactly).
- The two losses are summed with the config-supplied weights.

Mixed-stream samples (one batch containing both action targets and
text-only sub-recipes) are handled correctly: each sample contributes
where its labels are valid and is masked elsewhere.
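The masking and weighting described above can be sketched as a single reduction (hypothetical helper, simplified: the actual forward skips a branch entirely when its weight is zero instead of multiplying it out, and computes the flow loss via ``super().forward``):

```python
import torch
import torch.nn.functional as F


def dual_head_loss(
    per_sample_flow: torch.Tensor,   # (B,) flow-matching loss per sample
    predict_actions: torch.Tensor,   # (B,) bool: sample has a low_level target
    text_logits: torch.Tensor,       # (B, T, V) lm_head output on the prefix
    text_labels: torch.Tensor,       # (B, T), -100 where no text target
    flow_loss_weight: float = 1.0,
    text_loss_weight: float = 1.0,
) -> torch.Tensor:
    # Flow branch: mask non-action samples, renormalize by the action count.
    n_action = predict_actions.sum().clamp(min=1)
    flow = (per_sample_flow * predict_actions).sum() / n_action
    # Text branch: shifted next-token CE, ignoring -100 positions.
    ce = F.cross_entropy(
        text_logits[:, :-1].reshape(-1, text_logits.size(-1)),
        text_labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
    return flow_loss_weight * flow + text_loss_weight * ce
```

Because each sample contributes only where its mask or labels are valid, a batch mixing action-target and text-only samples yields well-defined gradients for both heads.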

Limitations / known follow-ups:

- Text loss runs an additional prefix-only forward separate from the
  flow path's prefix forward. The forwards could share their prefix
  computation; for clarity of this first commit they don't.
  Optimization is straightforward when needed.
- Per-sample loss for ``reduction="none"`` is not yet meaningfully
  defined for the dual path — we broadcast the scalar to (B,) for
  caller compatibility (e.g. RA-BC weighting will need follow-up).
- Inference ``select_action`` is unchanged from SmolVLA today —
  it predicts actions only. A separate "generate text"
  ``select_message`` path is the natural next step for runtime
  use of the LM head (memory updates, plan refreshes, VQA answers).
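A typical greedy / top-p token-selection step, as used by AR generation paths like ``select_message`` (a generic sketch, not this PR's exact code):

```python
import torch


def sample_next_token(logits: torch.Tensor, top_p: float = 0.9, greedy: bool = False) -> torch.Tensor:
    """Pick next-token ids from (B, V) logits: greedy argmax or top-p nucleus."""
    if greedy:
        return logits.argmax(dim=-1)
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_ids = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Mask tokens whose preceding cumulative mass already exceeds top_p
    # (the top-1 token is always kept).
    cutoff = cumulative - sorted_probs > top_p
    sorted_probs[cutoff] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_ids.gather(-1, choice).squeeze(-1)
```

In a KV-cached loop, a single prefix forward fills the cache, then each sampled token is fed back for a one-token forward until EOS or a length cap.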

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 19:54:57 +02:00
9 changed files with 1401 additions and 48 deletions
@@ -307,6 +307,7 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
@@ -0,0 +1,68 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 inference / runtime orchestration.

Multi-rate runtime that mirrors the recipe-time training shape:

    low_level_execution        → LowLevelForward + DispatchAction (high Hz)
    high_level_subtask         → HighLevelSubtaskFwd (~1 Hz)
    memory_update              → MemoryUpdateFwd (event: subtask_change)
    user_interjection_response → UserInterjectionFwd (event: stdin)
    ask_vqa_*                  → AskVQAFwd (event: stdin question)
    speech tool calls          → DispatchToolCalls (event: tool_call_pending)

The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and
calls ``run()``.
"""

from .repl import StdinReader
from .runtime import SmolVLA2Runtime
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
from .steps import (
    AskVQAFwd,
    DispatchAction,
    DispatchToolCalls,
    HighLevelSubtaskFwd,
    InferenceStep,
    LowLevelForward,
    MemoryUpdateFwd,
    UserInterjectionFwd,
)
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger

__all__ = [
    # runtime
    "SmolVLA2Runtime",
    "StdinReader",
    # state helpers
    "initial_runtime_state",
    "push_log",
    "set_if_changed",
    "take_event",
    # triggers
    "Trigger",
    "Tick",
    "TickClock",
    "HzTrigger",
    "EventTrigger",
    # steps
    "InferenceStep",
    "LowLevelForward",
    "DispatchAction",
    "HighLevelSubtaskFwd",
    "MemoryUpdateFwd",
    "UserInterjectionFwd",
    "AskVQAFwd",
    "DispatchToolCalls",
]
@@ -0,0 +1,87 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stdin REPL event collector for the SmolVLA2 runtime.

Reads non-blocking stdin lines, classifies each one heuristically:

    "stop" / "quit" / "exit"          → state["stop"] = True
    ends with "?"                     → user_vqa_query event
    starts with "task:" or first line → set runtime task
    anything else                     → user_interjection event

Plugged into the runtime via ``event_collector=StdinReader().poll``.
"""

from __future__ import annotations

import select
import sys
from dataclasses import dataclass, field
from typing import Any


@dataclass
class StdinReader:
    """Non-blocking stdin line collector for the runtime loop."""

    prompt: str = "> "
    _seen_first_line: bool = field(default=False, init=False)
    _prompted: bool = field(default=False, init=False)

    def poll(self, state: dict[str, Any]) -> None:
        """Drain pending stdin lines into runtime events."""
        # Print the input prompt once on every fresh tick if we don't
        # already have a pending line; matches the expected REPL feel.
        if not self._prompted:
            print(self.prompt, end="", flush=True)
            self._prompted = True
        # ``select`` with timeout=0 makes this non-blocking. Only works
        # for actual TTY / pipe stdins; CI / scripted runs hit EOF.
        try:
            ready, _, _ = select.select([sys.stdin], [], [], 0)
        except (ValueError, OSError):
            return
        if not ready:
            return
        line = sys.stdin.readline()
        if not line:  # EOF
            state["stop"] = True
            return
        line = line.strip()
        self._prompted = False  # we'll re-prompt next tick
        if not line:
            return
        lower = line.lower()
        if lower in {"stop", "quit", "exit"}:
            state["stop"] = True
            return
        # First non-control line sets the task if no task is active.
        if not state.get("task"):
            task = line[5:].strip() if lower.startswith("task:") else line
            state["task"] = task
            print(f"[smolvla2] Task: {task}", flush=True)
            self._seen_first_line = True
            return
        # Question → VQA; statement → interjection.
        if lower.endswith("?"):
            state["recent_vqa_query"] = line
            state.setdefault("events_this_tick", []).append("user_vqa_query")
        else:
            state["recent_interjection"] = line
            state.setdefault("events_this_tick", []).append("user_interjection")
@@ -0,0 +1,143 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 runtime loop.

Threads the multi-rate inference pipeline together with a stdin REPL
event collector, drives ticks through :class:`TickClock`, and prints
state-change updates to the user.
"""

from __future__ import annotations

import logging
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable

from .runtime_state import initial_runtime_state, push_log
from .steps import (
    AskVQAFwd,
    DispatchAction,
    DispatchToolCalls,
    HighLevelSubtaskFwd,
    InferenceStep,
    LowLevelForward,
    MemoryUpdateFwd,
    UserInterjectionFwd,
)
from .triggers import HzTrigger, TickClock

logger = logging.getLogger(__name__)


@dataclass
class SmolVLA2Runtime:
    """Compose the inference pipeline and drive it tick-by-tick."""

    policy: Any
    tools: dict[str, Any] = field(default_factory=dict)
    """Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
    from :func:`lerobot.tools.get_tools(meta)` when wiring the
    runtime."""
    observation_provider: Callable[[], dict | None] | None = None
    """Closure returning the current preprocessed observation batch.
    ``None`` for dry-run / language-only sessions."""
    robot_executor: Callable[[Any], None] | None = None
    """Closure that takes one action chunk and forwards it to the
    robot. ``None`` for dry-run."""
    event_collector: Callable[[dict], None] | None = None
    """Per-tick hook that polls external sources (stdin, network) and
    appends event names to ``state["events_this_tick"]``."""
    chunk_hz: float = 4.0
    ctrl_hz: float = 50.0
    high_level_hz: float = 1.0
    max_rate_hz: float = 50.0
    pipeline: list[InferenceStep] = field(init=False)
    state: dict[str, Any] = field(init=False)
    _stop: bool = field(default=False, init=False)

    def __post_init__(self) -> None:
        self.pipeline = [
            LowLevelForward(
                trigger=HzTrigger(self.chunk_hz),
                policy=self.policy,
                observation_provider=self.observation_provider,
            ),
            DispatchAction(
                trigger=HzTrigger(self.ctrl_hz),
                robot_executor=self.robot_executor,
            ),
            HighLevelSubtaskFwd(
                trigger=HzTrigger(self.high_level_hz),
                policy=self.policy,
            ),
            MemoryUpdateFwd(policy=self.policy),
            UserInterjectionFwd(policy=self.policy),
            AskVQAFwd(policy=self.policy),
            DispatchToolCalls(tools=self.tools),
        ]
        self.state = initial_runtime_state()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def set_task(self, task: str) -> None:
        """Set or replace the active task. Logged for the REPL."""
        self.state["task"] = task
        push_log(self.state, f"Task: {task}")

    def stop(self) -> None:
        self._stop = True

    def run(self, *, max_ticks: int | None = None) -> None:
        """Main loop. Returns when ``stop()`` is called or after
        ``max_ticks`` ticks (useful for tests / dry-run)."""
        clock = TickClock(max_rate_hz=self.max_rate_hz)
        while not self._stop:
            tick = clock.advance()
            self.state["_tick"] = tick
            self.state["events_this_tick"] = []
            self.state["log_lines"] = []
            if self.event_collector is not None:
                self.event_collector(self.state)
            if self.state.get("stop"):
                self._stop = True
                break
            for step in self.pipeline:
                self.state = step(self.state)
            self._flush_logs()
            if max_ticks is not None and tick.index >= max_ticks:
                break
        self._on_shutdown()

    # ------------------------------------------------------------------
    # I/O
    # ------------------------------------------------------------------
    def _flush_logs(self) -> None:
        for line in self.state.get("log_lines") or []:
            print(f"[smolvla2] {line}", flush=True)

    def _on_shutdown(self) -> None:
        # Drain any queued action chunks safely.
        queue = self.state.get("action_queue")
        if isinstance(queue, deque):
            queue.clear()
        print("[smolvla2] runtime stopped", flush=True)
@@ -0,0 +1,91 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runtime state passed between inference steps each tick.

The runtime threads a single dict through the pipeline; this module
documents the shape and provides factories. We use a plain ``dict``
rather than a frozen dataclass because steps freely add and remove
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
…) and dataclass field churn would just get in the way.

Stable keys (read by multiple steps):

    task                 str                        the current top-level task
    current_plan         str | None                 latest plan emitted by the planner
    current_subtask      str | None                 latest subtask the policy is executing
    current_memory       str | None                 latest compressed memory
    recent_interjection  str | None                 most recent user interjection text (consumed)
    action_queue         collections.deque[Tensor]  pending action chunks
    tool_calls_pending   list[dict]                 parsed but not-yet-dispatched tool calls
    events_this_tick     list[str]                  triggers consumed this tick
    _tick                Tick                       current tick (set by the loop)
    log_lines            list[str]                  human-readable status lines printed each tick
"""

from __future__ import annotations

from collections import deque
from typing import Any


def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
    """Build a fresh runtime state dict with sensible defaults."""
    return {
        "task": task,
        "current_plan": None,
        "current_subtask": None,
        "current_memory": None,
        "recent_interjection": None,
        "action_queue": deque(),
        "tool_calls_pending": [],
        "events_this_tick": [],
        "log_lines": [],
        "stop": False,
    }


def take_event(state: dict[str, Any], event_name: str) -> bool:
    """Pop ``event_name`` from ``events_this_tick`` if present.

    Steps that consume an event call this so the same event doesn't
    re-fire on a sibling step within the same tick.
    """
    events: list[str] = state.get("events_this_tick") or []
    if event_name in events:
        events.remove(event_name)
        return True
    return False


def push_log(state: dict[str, Any], line: str) -> None:
    """Append ``line`` to the per-tick log buffer; the runtime prints
    it at the end of the tick."""
    state.setdefault("log_lines", []).append(line)


def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
    """Update ``state[key]`` and log a diff line if the value changed.

    Returns ``True`` if the value actually changed.
    """
    prev = state.get(key)
    if prev == value:
        return False
    state[key] = value
    if label is not None:
        push_log(state, f"  {label}: {value}")
    return True
@@ -0,0 +1,382 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference steps for the SmolVLA2 multi-rate runtime.

Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
the runtime applies them in order each tick. When a step's trigger
doesn't fire, the step is a no-op and the runtime moves on.

Stream-to-step mapping mirrors the ``smolvla2_hirobot.yaml`` recipe:

* ``LowLevelForward``     — calls ``policy.select_action`` for the
                            action chunk and pushes it to
                            ``action_queue``; trained by
                            ``low_level_execution``
* ``DispatchAction``      — pops one action per control tick and
                            forwards it to the robot
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
                            next subtask; trained by
                            ``high_level_subtask``
* ``MemoryUpdateFwd``     — fires on subtask boundary; trained by
                            ``memory_update``
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
                            ``user_interjection_response``
* ``AskVQAFwd``           — fires on stdin question; trained by
                            ``ask_vqa_*``
* ``DispatchToolCalls``   — pops ``tool_calls_pending`` and calls
                            the matching ``Tool`` instance
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from typing import Any

from .runtime_state import push_log, set_if_changed, take_event
from .triggers import EventTrigger, HzTrigger, Trigger

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Step base + runner
# ---------------------------------------------------------------------------


@dataclass
class InferenceStep:
    """A trigger-gated callable. Subclasses override :meth:`run`."""

    trigger: Trigger

    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        if not self.trigger.should_fire(state["_tick"], state):
            return state
        return self.run(state) or state

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:  # pragma: no cover
        raise NotImplementedError
# ---------------------------------------------------------------------------
# Low-level (action) path
# ---------------------------------------------------------------------------


@dataclass
class LowLevelForward(InferenceStep):
    """Run the policy's action head and produce one action chunk."""

    policy: Any = None
    observation_provider: Any = None
    """Callable ``() -> dict``: returns the current observation batch
    (already preprocessed). Typically wraps the robot's camera /
    proprio reads. ``None`` in dry-run mode → step skips."""
    trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or self.observation_provider is None:
            return None
        observation = self.observation_provider()
        if observation is None:
            return None
        action = self.policy.select_action(observation)
        # SmolVLA returns a single action; if the underlying policy
        # streams chunks, split per-step here. For v1 we just enqueue
        # the result.
        state.setdefault("action_queue", []).append(action)
        return None


@dataclass
class DispatchAction(InferenceStep):
    """Pop one action per tick and hand it to the robot.

    In dry-run mode (``robot_executor=None``) the step still pops the
    queue so it doesn't grow unbounded — the popped tensor is logged
    instead of executed.
    """

    robot_executor: Any = None
    trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        queue = state.get("action_queue")
        if not queue:
            return None
        action = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
        if self.robot_executor is not None:
            self.robot_executor(action)
        return None


# ---------------------------------------------------------------------------
# High-level (text) paths — all use policy.select_message
# ---------------------------------------------------------------------------


def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dict[str, Any]:
    """Tokenize a list of chat messages into the batch shape
    ``select_message`` expects.

    Lazy fallback: re-uses the policy's preprocessor by piggy-backing
    on the chat tokenizer step. Production use should construct the
    batch from a real observation; here we focus on the *language*
    path which is independent of camera observations.
    """
    from transformers import AutoTokenizer  # noqa: PLC0415

    tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    text_messages = [_strip_recipe_keys(m) for m in prompt_messages]
    ids = tokenizer.apply_chat_template(
        text_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )
    if isinstance(ids, list):
        ids = ids[0] if ids else []
    if hasattr(ids, "ndim") and ids.ndim == 1:
        ids = ids.unsqueeze(0)
    attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
    return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}


def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
    new = dict(m)
    new.pop("stream", None)
    new.pop("target", None)
    return new
@dataclass
class HighLevelSubtaskFwd(InferenceStep):
    """At ~1 Hz, ask the policy for the next subtask."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or not state.get("task"):
            return None
        ctx = _control_context_messages(state)
        msg = _generate_with_policy(self.policy, ctx)
        if msg:
            changed = set_if_changed(state, "current_subtask", msg, label="subtask")
            if changed:
                # Subtask change is a downstream trigger.
                state.setdefault("events_this_tick", []).append("subtask_change")
        return None


@dataclass
class MemoryUpdateFwd(InferenceStep):
    """On subtask boundary, refresh the compressed memory."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        # Don't consume the event — multiple steps may want to react.
        if self.policy is None:
            return None
        ctx = _control_context_messages(state, include_completed=True)
        new_memory = _generate_with_policy(self.policy, ctx)
        if new_memory:
            set_if_changed(state, "current_memory", new_memory, label="memory")
        return None


@dataclass
class UserInterjectionFwd(InferenceStep):
    """On stdin interjection, refresh the plan + emit a paired ``say``."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or not take_event(state, "user_interjection"):
            return None
        ctx = _control_context_messages(
            state,
            extra_user=state.get("recent_interjection"),
        )
        out = _generate_with_policy(self.policy, ctx)
        if not out:
            return None
        # Heuristic split: model is trained to emit one assistant turn
        # carrying both plan text AND a `say` tool call. Look for a
        # "<say>...</say>" or "say(...)" marker; fall back to whole
        # text → plan, no speech.
        plan_text, speech_text = _split_plan_and_say(out)
        if plan_text:
            set_if_changed(state, "current_plan", plan_text, label="plan")
        if speech_text:
            push_log(state, f"  speech: {speech_text}")
            state.setdefault("tool_calls_pending", []).append(
                {
                    "type": "function",
                    "function": {"name": "say", "arguments": {"text": speech_text}},
                }
            )
            state.setdefault("events_this_tick", []).append("tool_call_pending")
        # Mark interjection consumed.
        state["recent_interjection"] = None
        return None


@dataclass
class AskVQAFwd(InferenceStep):
    """On stdin question, answer a frame-grounded VQA."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or not take_event(state, "user_vqa_query"):
            return None
        question = state.get("recent_vqa_query")
        if not question:
            return None
        ctx = _control_context_messages(state, extra_user=question)
        answer = _generate_with_policy(self.policy, ctx)
        if answer:
            push_log(state, f"  vqa: {answer}")
        state["recent_vqa_query"] = None
        return None
# ---------------------------------------------------------------------------
# Tool dispatch
# ---------------------------------------------------------------------------


@dataclass
class DispatchToolCalls(InferenceStep):
    """Pop ``tool_calls_pending`` and execute them via the ``tools`` registry."""

    tools: dict[str, Any] = field(default_factory=dict)
    trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        take_event(state, "tool_call_pending")
        pending = state.get("tool_calls_pending") or []
        for call in pending:
            try:
                fn = (call or {}).get("function") or {}
                name = fn.get("name")
                args = fn.get("arguments") or {}
                tool = self.tools.get(name)
                if tool is None:
                    push_log(state, f"  [warn] tool {name!r} not registered — skipping call")
                    continue
                tool.call(args)
            except Exception as exc:  # noqa: BLE001
                push_log(state, f"  [error] tool dispatch failed: {exc}")
        state["tool_calls_pending"] = []
        return None


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _control_context_messages(
    state: dict[str, Any],
    *,
    include_completed: bool = False,
    extra_user: str | None = None,
) -> list[dict[str, Any]]:
    """Build a chat-template-ready prompt from current runtime state.

    Mirrors what ``smolvla2_hirobot.yaml`` renders into ``${task}\nPlan:
    ${plan}\nMemory: ${memory}`` for the high-level branches.
    """
    parts: list[str] = []
    task = state.get("task") or ""
    parts.append(task)
    if state.get("current_plan"):
        parts.append(f"Plan: {state['current_plan']}")
    if state.get("current_memory"):
        parts.append(f"Memory: {state['current_memory']}")
    if include_completed and state.get("current_subtask"):
        parts.append(f"Completed subtask: {state['current_subtask']}")
    head = "\n".join(parts)
    msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
    if extra_user:
        msgs.append({"role": "user", "content": extra_user})
    return msgs


def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str:
    """Drive ``policy.select_message`` with a minimal text-only batch.

    Best-effort: the runtime today doesn't construct a full
    observation batch with images / state for text generation; the
    text head was trained over images + lang + state, so generations
    here may differ in distribution from training. This is acceptable
    for a v1 REPL; a follow-up will plug in the real observation.
    """
    if not hasattr(policy, "select_message"):
        return ""
    text_batch = _build_text_batch(policy, messages)
    # ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS.
    # The minimal text-only batch we build doesn't have images / state,
    # so we either run a text-only forward (handled by SmolVLA2 when
    # supported) or skip and return empty. v1 returns empty when the
    # policy can't handle it; the runtime logs and continues.
    try:
        # Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK
        # keys ``select_message`` uses internally.
        from lerobot.utils.constants import (  # noqa: PLC0415
            OBS_LANGUAGE_ATTENTION_MASK,
            OBS_LANGUAGE_TOKENS,
        )

        batch = {
            OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
            OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
        }
        return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
    except Exception as exc:  # noqa: BLE001
        logger.debug("select_message fell back: %s", exc)
        return ""


_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)


def _split_plan_and_say(text: str) -> tuple[str, str]:
    """Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.

    The training-time tool-call serializer wraps ``say(text="")`` in a
    deterministic textual marker so prefix-LM-style training learns to
    emit it. The runtime parses it back here. If no marker is present,
    the entire text is treated as plan with no speech.
    """
    if not text:
        return "", ""
    match = _SAY_RE.search(text)
    if not match:
        return text.strip(), ""
    speech = match.group(1).strip().strip('"').strip("'")
    plan = (text[: match.start()] + text[match.end() :]).strip()
    return plan, speech
@@ -0,0 +1,117 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trigger primitives for SmolVLA2's multi-rate inference runtime.
Mirrors the plan's "Runtime orchestration" section: each
``InferenceStep`` is gated by a :class:`Trigger` that decides, per
tick, whether the step fires. Two trigger flavours cover all the
cadences the canonical recipe needs:
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
memory update; user interjection → plan refresh; user VQA query →
VQA answer; pending tool call → dispatcher)
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
The runtime stores the shared :class:`TickClock` as ``state["_tick"]``
so every step reads from a single time source.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any, Protocol
@dataclass
class Tick:
"""Single tick from :class:`TickClock`. Carries time references the
runtime steps consume to gate themselves."""
index: int
"""Monotonic counter — increments by one per tick."""
monotonic_seconds: float
"""``time.monotonic()`` at the start of this tick."""
@dataclass
class TickClock:
"""Drives the runtime loop at up to ``max_rate_hz``.
Sleeps just enough between :meth:`advance` calls to enforce the
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
"""
max_rate_hz: float = 50.0
_index: int = field(default=0, init=False)
_last_seconds: float | None = field(default=None, init=False)
def advance(self) -> Tick:
period = 1.0 / max(self.max_rate_hz, 0.1)
now = time.monotonic()
if self._last_seconds is not None:
sleep_for = (self._last_seconds + period) - now
if sleep_for > 0:
time.sleep(sleep_for)
now = time.monotonic()
self._last_seconds = now
self._index += 1
return Tick(index=self._index, monotonic_seconds=now)
class Trigger(Protocol):
"""Decide whether the next ``InferenceStep`` should fire."""
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
@dataclass
class HzTrigger:
"""Fire at most ``hz`` times per second."""
hz: float
_last_seconds: float | None = field(default=None, init=False)
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
period = 1.0 / max(self.hz, 1e-6)
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
self._last_seconds = tick.monotonic_seconds
return True
return False
@dataclass
class EventTrigger:
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
The runtime fills ``events_this_tick`` once per tick from:
* stdin / network input (``user_interjection``, ``user_vqa_query``,
``stop``)
* internal state transitions (``subtask_change``,
``tool_call_pending``)
The list is consumed (cleared at the end of the tick) so events
fire at most once.
"""
event_name: str
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
events: list[str] = state.get("events_this_tick") or []
return self.event_name in events
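The consume-once contract can be illustrated against a hand-built state dict (a sketch — the real runtime fills and clears ``events_this_tick`` itself):

```python
# Minimal restatement of EventTrigger's membership check.
def should_fire(event_name: str, state: dict) -> bool:
    events = state.get("events_this_tick") or []
    return event_name in events

state = {"events_this_tick": ["user_interjection"]}

fired = should_fire("user_interjection", state)   # True — event present this tick
ignored = should_fire("user_vqa_query", state)    # False — different event
state["events_this_tick"] = []                    # runtime clears at end of tick
refire = should_fire("user_interjection", state)  # False — already consumed
```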
@@ -13,60 +13,62 @@
# limitations under the License.
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
* a forward path that runs the flow head, the text head, or both,
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on
this branch).
Per-sample routing — within one batch:
* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
loss (action chunk supervision).
* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
contribute to the LM-head cross-entropy.
Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
``text_labels`` nor ``predict_actions`` is in the batch — unannotated
datasets keep working unchanged.
"""
from __future__ import annotations
import math
from typing import Any
import torch
import torch.nn.functional as F
from torch import Tensor
from ..smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
from .configuration_smolvla2 import SmolVLA2Config
class SmolVLA2Policy(SmolVLAPolicy):
"""SmolVLA + re-enabled SmolVLM language head."""
config_class = SmolVLA2Config
name = "smolvla2"
def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
if not isinstance(config, SmolVLA2Config):
# Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
# widening the config type — the new fields fall back to their
# defaults, which preserves the existing SmolVLA behaviour.
config = SmolVLA2Config(
**{
f.name: getattr(config, f.name)
for f in config.__dataclass_fields__.values()
if hasattr(config, f.name)
}
)
super().__init__(config, dataset_stats=dataset_stats)
if config.unfreeze_lm_head and config.text_loss_weight > 0:
self._unfreeze_lm_head()
@@ -76,13 +78,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
# ------------------------------------------------------------------
def _unfreeze_lm_head(self) -> None:
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
of the text path SmolVLA freezes) so the text-loss can flow back.
"""
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
if vlm_with_expert is None:
@@ -91,10 +88,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
if vlm is None:
return
for name, param in vlm.named_parameters():
if "lm_head" in name or "text_model.model.norm.weight" in name:
param.requires_grad = True
# ------------------------------------------------------------------
@@ -108,12 +102,286 @@ class SmolVLA2Policy(SmolVLAPolicy):
time: Tensor | None = None,
reduction: str = "mean",
) -> tuple[Tensor, dict[str, Any]]:
"""Forward pass with optional dual-head loss.
Two routing knobs from the batch (produced by
:class:`SmolVLA2ChatTokenizerStep`):
* ``text_labels`` — per-token labels with ``-100`` for non-target
positions. Triggers the text-loss path through ``lm_head``.
* ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
include this sample's action chunk in the flow loss.
When neither is present, delegate to ``SmolVLAPolicy.forward``.
"""
text_labels = batch.get("text_labels")
predict_actions_t = batch.get("predict_actions")
# ``isinstance`` already rejects ``None``, so a separate None-check is redundant.
has_text_data = isinstance(text_labels, Tensor) and self.config.text_loss_weight > 0
has_per_sample_routing = isinstance(predict_actions_t, Tensor)
if not has_text_data and not has_per_sample_routing:
return super().forward(batch, noise=noise, time=time, reduction=reduction)
loss_dict: dict[str, Any] = {}
device = batch[OBS_STATE].device
total = torch.zeros((), device=device, dtype=torch.float32)
# ------------------------------------------------------------
# Flow loss path — only when at least one sample wants actions.
# ------------------------------------------------------------
run_flow = self.config.flow_loss_weight > 0 and (
not has_per_sample_routing or bool(predict_actions_t.any().item())
)
if run_flow and ACTION in batch:
per_sample_flow, flow_diag = super().forward(
batch, noise=noise, time=time, reduction="none"
)
# ``per_sample_flow`` has shape (B,) from the SmolVLA
# reduction="none" branch.
if has_per_sample_routing:
mask = predict_actions_t.to(per_sample_flow.dtype)
masked = per_sample_flow * mask
denom = mask.sum().clamp(min=1.0)
flow_loss = masked.sum() / denom
else:
flow_loss = per_sample_flow.mean()
total = total + self.config.flow_loss_weight * flow_loss
loss_dict["flow_loss"] = float(flow_loss.detach().item())
for k, v in flow_diag.items():
loss_dict[f"flow_{k}"] = v
# ------------------------------------------------------------
# Text loss path — prefix-only forward → lm_head → CE.
# ------------------------------------------------------------
if has_text_data:
text_loss = self._compute_text_loss(batch, text_labels)
total = total + self.config.text_loss_weight * text_loss
loss_dict["text_loss"] = float(text_loss.detach().item())
loss_dict["loss"] = float(total.detach().item())
if reduction == "none":
# Per-sample loss isn't meaningfully defined for the dual
# path; broadcast the scalar to (B,) for caller compat.
return total.expand(batch[OBS_STATE].shape[0]), loss_dict
return total, loss_dict
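The per-sample routing above reduces to a masked mean with a clamped denominator; a small numeric check with hypothetical per-sample flow losses:

```python
import torch

per_sample_flow = torch.tensor([0.5, 1.5, 2.0])      # reduction="none" output, shape (B,)
predict_actions = torch.tensor([True, False, True])  # sample 1 is text-only

mask = predict_actions.to(per_sample_flow.dtype)
denom = mask.sum().clamp(min=1.0)   # clamp guards the all-False (text-only) batch
flow_loss = (per_sample_flow * mask).sum() / denom
# flow_loss == (0.5 + 2.0) / 2 == 1.25 — sample 1 contributes nothing
```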
# ------------------------------------------------------------------
# Text-loss internals
# ------------------------------------------------------------------
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
"""Cross-entropy on the SmolVLM ``lm_head`` over target tokens."""
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Prefix-only forward.
out_pair, _ = self.model.vlm_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
fill_kv_cache=False,
)
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
if prefix_out is None:
raise RuntimeError(
"SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
"states — text-loss path needs them."
)
# Lang token positions inside the prefix. ``embed_prefix`` lays
# out the prefix as ``[image_blocks..., lang, state]`` so the
# lang range is identifiable from the trailing state size and
# the known lang length.
num_lang = lang_tokens.shape[1]
state_for_dim = state if state.ndim >= 2 else state[:, None]
# After the line above ``state_for_dim.ndim >= 2`` always holds; clamp to >= 1.
num_state = max(state_for_dim.shape[1], 1)
prefix_len = prefix_out.shape[1]
lang_end = prefix_len - num_state
lang_start = lang_end - num_lang
if lang_start < 0 or lang_end > prefix_len:
raise RuntimeError(
f"SmolVLA2: could not locate lang token range in prefix "
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
f"num_state={num_state})."
)
lang_hidden = prefix_out[:, lang_start:lang_end]
vlm = self.model.vlm_with_expert.vlm
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
if text_labels.shape[1] != num_lang:
common = min(text_labels.shape[1], num_lang)
logits = logits[:, :common]
text_labels = text_labels[:, :common]
# Standard next-token CE: hidden state at position t predicts
# token at position t+1. Shift logits left, labels right by 1.
# Without this, the loss is identity-mapped and the LM head
# learns nothing useful — see HuggingFace ``LlamaForCausalLM``
# for the same convention.
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
)
return loss
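The shift convention can be verified numerically: with uniform logits over a vocabulary of ``V`` tokens, the masked CE equals ``log(V)`` no matter where the ``-100`` padding sits (standalone sketch with toy shapes):

```python
import math
import torch
import torch.nn.functional as F

B, T, V = 1, 4, 8
logits = torch.zeros(B, T, V)                # uniform distribution after softmax
labels = torch.tensor([[-100, 3, 5, -100]])  # -100 = not a target position

shift_logits = logits[:, :-1, :].contiguous()  # hidden state at t predicts...
shift_labels = labels[:, 1:].contiguous()      # ...the label at t + 1

loss = F.cross_entropy(
    shift_logits.reshape(-1, V),
    shift_labels.reshape(-1),
    ignore_index=-100,
)
# loss == log(8) — only the two real target positions contribute
```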
# ------------------------------------------------------------------
# Inference: text generation
# ------------------------------------------------------------------
@torch.no_grad()
def select_message(
self,
batch: dict[str, Tensor],
*,
max_new_tokens: int = 256,
eos_token_id: int | None = None,
temperature: float = 0.0,
top_p: float = 1.0,
tokenizer: Any = None,
) -> str:
"""Generate text continuation from the chat-templated prompt.
AR decoding with KV caching reused from SmolVLA's inference
path. Batch size is assumed to be 1 (the runtime calls this
per-event). Returns the decoded string of new tokens (the
prompt itself is not included).
Parameters
----------
batch:
Already through the SmolVLA2 preprocessor — expects
``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``,
``OBS_LANGUAGE_ATTENTION_MASK``.
max_new_tokens:
Hard cap on generated tokens; stops earlier on EOS.
eos_token_id:
Override the tokenizer's EOS. ``None`` ⇒ use the
tokenizer's default.
temperature, top_p:
``temperature=0`` does greedy argmax (default — matches
training distribution most closely). Set ``temperature>0``
with optional ``top_p<1`` for nucleus sampling.
tokenizer:
Optional pre-loaded tokenizer to avoid the cold-start
``AutoTokenizer.from_pretrained`` round-trip on every call.
"""
self.eval()
if tokenizer is None:
from transformers import AutoTokenizer # noqa: PLC0415
tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name)
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
# 1) Embed prefix (images + lang + state) and run with KV cache.
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
prefix_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_pos = torch.cumsum(prefix_pad_masks, dim=1) - 1
out_pair, past_kv = self.model.vlm_with_expert.forward(
attention_mask=prefix_2d,
position_ids=prefix_pos,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=True,
fill_kv_cache=True,
)
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
if prefix_out is None:
raise RuntimeError("select_message: prefix forward returned no hidden states.")
vlm = self.model.vlm_with_expert.vlm
# 2) Initial logits — sample first new token from the last
# prefix position.
last_hidden = prefix_out[:, -1:]
device = last_hidden.device
bsize = prefix_embs.shape[0]
cur_pos = int(prefix_embs.shape[1])
generated: list[int] = []
for _ in range(max_new_tokens):
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
next_ids = self._sample_next_token(logits_step, temperature, top_p)
tok_id = int(next_ids[0].item())
generated.append(tok_id)
if eos_token_id is not None and tok_id == eos_token_id:
break
# 3) Embed the new token and forward, reusing the KV cache.
new_emb = self.model.vlm_with_expert.embed_language_tokens(
next_ids.unsqueeze(1)  # (B, 1): add the sequence dim, not a batch dim
)
new_emb = new_emb * math.sqrt(new_emb.shape[-1])
new_pos = torch.full((bsize, 1), cur_pos, device=device, dtype=torch.long)
new_attn = torch.ones((bsize, cur_pos + 1), device=device, dtype=torch.bool)
out_pair, past_kv = self.model.vlm_with_expert.forward(
attention_mask=new_attn,
position_ids=new_pos,
past_key_values=past_kv,
inputs_embeds=[new_emb, None],
use_cache=True,
fill_kv_cache=False,  # the prefix forward already filled the cache — reuse it
)
new_prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
last_hidden = new_prefix_out[:, -1:]
cur_pos += 1
return tokenizer.decode(generated, skip_special_tokens=True).strip()
@staticmethod
def _sample_next_token(
logits: Tensor, temperature: float, top_p: float
) -> Tensor:
"""Pick one token id per batch row from ``logits``."""
if temperature <= 0.0:
return logits.argmax(dim=-1)
scaled = logits / max(temperature, 1e-6)
probs = F.softmax(scaled, dim=-1)
if top_p < 1.0:
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
cum = sorted_probs.cumsum(dim=-1)
mask = cum > top_p
# Always keep the most-likely token.
mask[..., 0] = False
sorted_probs = sorted_probs.masked_fill(mask, 0.0)
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
pick = torch.multinomial(sorted_probs, num_samples=1)
return sorted_idx.gather(-1, pick).squeeze(-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
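A quick numeric check of the nucleus mask on a hypothetical 4-token vocabulary: with ``top_p=0.7`` only the head of the sorted distribution keeps mass, so sampling here becomes deterministic:

```python
import torch

probs = torch.tensor([[0.6, 0.2, 0.1, 0.1]])  # already normalised, already sorted
top_p = 0.7

sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
cum = sorted_probs.cumsum(dim=-1)             # [0.6, 0.8, 0.9, 1.0]
mask = cum > top_p                            # [F, T, T, T]
mask[..., 0] = False                          # always keep the argmax
sorted_probs = sorted_probs.masked_fill(mask, 0.0)
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)

pick = torch.multinomial(sorted_probs, num_samples=1)
token = sorted_idx.gather(-1, pick).squeeze(-1)
# token == tensor([0]) — index 0 is the only nonzero-mass token after masking
```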
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2.
Drives the multi-rate runtime defined in
:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user
channel: type a task, then natural-language interjections / questions.
The runtime prints state changes (plan / subtask / memory / vqa /
speech) as they happen.
Examples
--------
Dry run on a checkpoint, no robot connected — useful for sanity-
checking text generation::
uv run lerobot-smolvla2-runtime \\
--policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
--no_robot \\
--task="please clean the kitchen"
With a real robot::
uv run lerobot-smolvla2-runtime \\
--policy.path=... \\
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
--tts.voice=alba
Tool dispatch (TTS via ``SayTool``) is enabled by default when
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
logger = logging.getLogger("lerobot.smolvla2.runtime")
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p = argparse.ArgumentParser(
prog="lerobot-smolvla2-runtime",
description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
)
p.add_argument(
"--policy.path",
dest="policy_path",
type=Path,
required=True,
help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
)
p.add_argument(
"--task",
dest="task",
type=str,
default=None,
help="Initial task. If omitted, the first stdin line is treated as the task.",
)
p.add_argument(
"--no_robot",
action="store_true",
help="Skip robot connection — language-only / dry-run mode.",
)
p.add_argument(
"--no_tts",
action="store_true",
help="Disable the ``say`` tool dispatch.",
)
p.add_argument(
"--tts.voice",
dest="tts_voice",
type=str,
default="alba",
help="Pocket-tts voice name (or path to a .wav for cloning).",
)
p.add_argument(
"--chunk_hz", type=float, default=4.0, help="Action-chunk generation rate."
)
p.add_argument(
"--ctrl_hz", type=float, default=50.0, help="Action dispatch rate."
)
p.add_argument(
"--high_level_hz",
type=float,
default=1.0,
help="High-level subtask generation rate.",
)
p.add_argument(
"--max_ticks",
type=int,
default=None,
help="Stop after N ticks (debug / smoke-test).",
)
p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
return p.parse_args(argv)
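The dotted flag names above lean on argparse's ``dest`` mapping: without it, ``--policy.path`` would store its value under the attribute name ``"policy.path"``, reachable only via ``getattr``. A standalone sketch of the pattern:

```python
import argparse

p = argparse.ArgumentParser()
# dest= remaps the dotted flag to a valid Python identifier.
p.add_argument("--policy.path", dest="policy_path", type=str)
p.add_argument("--tts.voice", dest="tts_voice", type=str, default="alba")

ns = p.parse_args(["--policy.path", "outputs/train/ckpt"])
# ns.policy_path == "outputs/train/ckpt"; ns.tts_voice falls back to "alba"
```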
def _load_policy(path: Path): # noqa: ANN202
"""Load a SmolVLA2 checkpoint from ``path``."""
from lerobot.policies.factory import make_policy_from_path # noqa: PLC0415
policy = make_policy_from_path(str(path))
policy.eval()
return policy
def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
"""Instantiate the tools declared on this dataset/policy."""
if no_tts:
return {}
try:
from lerobot.tools import SayTool # noqa: PLC0415
return {"say": SayTool(voice=tts_voice)}
except Exception as exc: # noqa: BLE001
logger.warning("Could not initialise SayTool (%s) — speech disabled.", exc)
return {}
def main(argv: list[str] | None = None) -> int:
args = _parse_args(argv)
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
if not args.policy_path.exists():
print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
return 1
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
policy = _load_policy(args.policy_path)
tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
if tools:
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
# Robot wiring is left as a follow-up — for v1 we run language-only
# / dry-run so REPL development doesn't require a connected robot.
observation_provider = None
robot_executor = None
if not args.no_robot:
print(
"[smolvla2] WARNING: real-robot integration is a follow-up. "
"Running in dry-run mode for now (no actions executed).",
flush=True,
)
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
SmolVLA2Runtime,
StdinReader,
)
runtime = SmolVLA2Runtime(
policy=policy,
tools=tools,
observation_provider=observation_provider,
robot_executor=robot_executor,
event_collector=StdinReader().poll,
chunk_hz=args.chunk_hz,
ctrl_hz=args.ctrl_hz,
high_level_hz=args.high_level_hz,
)
if args.task:
runtime.set_task(args.task)
print(
"[smolvla2] runtime ready. Type a task to begin, then any line for "
"interjections, questions ending in '?' for VQA, or 'stop' to exit.",
flush=True,
)
try:
runtime.run(max_ticks=args.max_ticks)
except KeyboardInterrupt:
runtime.stop()
print("\n[smolvla2] interrupted by user", flush=True)
return 0
if __name__ == "__main__":
sys.exit(main())