feat(smolvla2): runtime accepts Hub IDs + dataset-driven dry-run

The runtime CLI's loader was broken — it imported a `make_policy_from_path`
that doesn't exist in `lerobot.policies.factory` — and the high-level text
steps generated plan / subtask / memory / VQA from a text-only batch with
no images or state, so dry-runs drifted from the training distribution.

Switch to the standard `PreTrainedConfig.from_pretrained` +
`make_policy(cfg, ds_meta=...)` flow so `--policy.path` accepts both local
directories and Hub repo ids, and add a `--dataset.repo_id` path that walks
a chosen episode and feeds preprocessed observations into every forward
pass — including the four high-level steps (`HighLevelSubtaskFwd`,
`MemoryUpdateFwd`, `UserInterjectionFwd`, `AskVQAFwd`). Frames are routed
through the saved preprocessor pipeline with `language_persistent` /
`language_events` stripped so the recipe-render step stays a no-op (the
runtime supplies its own messages from current state).

Also wires the rich-based two-zone REPL layout (`ui.py`) that the script
was already importing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 11:08:53 +02:00
parent a1b8134ef1
commit 3fe686ce9f
5 changed files with 520 additions and 57 deletions
@@ -40,6 +40,7 @@ from .steps import (
UserInterjectionFwd,
)
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
from .ui import make_state_panel, print_robot_lines, print_user_line
__all__ = [
# runtime
@@ -65,4 +66,8 @@ __all__ = [
"UserInterjectionFwd",
"AskVQAFwd",
"DispatchToolCalls",
# UI
"make_state_panel",
"print_robot_lines",
"print_user_line",
]
@@ -82,10 +82,20 @@ class SmolVLA2Runtime:
HighLevelSubtaskFwd(
trigger=HzTrigger(self.high_level_hz),
policy=self.policy,
observation_provider=self.observation_provider,
),
MemoryUpdateFwd(
policy=self.policy,
observation_provider=self.observation_provider,
),
UserInterjectionFwd(
policy=self.policy,
observation_provider=self.observation_provider,
),
AskVQAFwd(
policy=self.policy,
observation_provider=self.observation_provider,
),
MemoryUpdateFwd(policy=self.policy),
UserInterjectionFwd(policy=self.policy),
AskVQAFwd(policy=self.policy),
DispatchToolCalls(tools=self.tools),
]
self.state = initial_runtime_state()
@@ -127,6 +137,39 @@ class SmolVLA2Runtime:
self._on_shutdown()
# ------------------------------------------------------------------
# REPL helper: drive one full pipeline pass and return its logs
# ------------------------------------------------------------------
def step_once(self) -> list[str]:
"""Run one tick of the pipeline and return the log lines.
Used by the interactive REPL: instead of a background thread,
the CLI drives ticks synchronously after each user input. Logs
are returned (not printed) so the caller can route them into
the rich-Live chat scrollback.
"""
from .triggers import Tick # noqa: PLC0415
# Synthesize a tick. We don't need the real wall-clock pacing
# here — the REPL drives the runtime, not vice versa — but
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
# bump it generously so every Hz-triggered step considers
# itself due.
import time as _time # noqa: PLC0415
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
self.state["log_lines"] = []
# ``events_this_tick`` is set up by the caller before
# ``step_once`` (the REPL pushes user-driven events first).
self.state.setdefault("events_this_tick", [])
for step in self.pipeline:
self.state = step(self.state)
return list(self.state.get("log_lines") or [])
# ------------------------------------------------------------------
# I/O
# ------------------------------------------------------------------
@@ -171,13 +171,19 @@ class HighLevelSubtaskFwd(InferenceStep):
"""At ~1 Hz, ask the policy for the next subtask."""
policy: Any = None
observation_provider: Any = None
"""Same shape as ``LowLevelForward.observation_provider``. When
set, the resulting observation is merged into ``select_message``'s
batch so text generation runs against real video + state."""
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not state.get("task"):
return None
ctx = _control_context_messages(state)
msg = _generate_with_policy(self.policy, ctx)
observation = _maybe_observation(self.observation_provider)
msg = _generate_with_policy(self.policy, ctx, observation=observation)
if msg:
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
if changed:
@@ -191,6 +197,7 @@ class MemoryUpdateFwd(InferenceStep):
"""On subtask boundary, refresh the compressed memory."""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
@@ -198,7 +205,8 @@ class MemoryUpdateFwd(InferenceStep):
if self.policy is None:
return None
ctx = _control_context_messages(state, include_completed=True)
new_memory = _generate_with_policy(self.policy, ctx)
observation = _maybe_observation(self.observation_provider)
new_memory = _generate_with_policy(self.policy, ctx, observation=observation)
if new_memory:
set_if_changed(state, "current_memory", new_memory, label="memory")
return None
@@ -209,6 +217,7 @@ class UserInterjectionFwd(InferenceStep):
"""On stdin interjection, refresh the plan + emit a paired ``say``."""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
@@ -218,7 +227,8 @@ class UserInterjectionFwd(InferenceStep):
state,
extra_user=state.get("recent_interjection"),
)
out = _generate_with_policy(self.policy, ctx)
observation = _maybe_observation(self.observation_provider)
out = _generate_with_policy(self.policy, ctx, observation=observation)
if not out:
return None
# Heuristic split: model is trained to emit one assistant turn
@@ -247,6 +257,7 @@ class AskVQAFwd(InferenceStep):
"""On stdin question, answer a frame-grounded VQA."""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
@@ -256,7 +267,8 @@ class AskVQAFwd(InferenceStep):
if not question:
return None
ctx = _control_context_messages(state, extra_user=question)
answer = _generate_with_policy(self.policy, ctx)
observation = _maybe_observation(self.observation_provider)
answer = _generate_with_policy(self.policy, ctx, observation=observation)
if answer:
push_log(state, f" vqa: {answer}")
state["recent_vqa_query"] = None
@@ -326,35 +338,54 @@ def _control_context_messages(
return msgs
def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str:
"""Drive ``policy.select_message`` with a minimal text-only batch.
def _maybe_observation(provider: Any) -> dict | None:
"""Pull one observation from ``provider`` if it's set, else ``None``.
Best-effort: the runtime today doesn't construct a full
observation batch with images / state for text generation; the
text-head was trained over images + lang + state, so generations
here may differ in distribution from training. This is acceptable
for a v1 REPL; a follow-up will plug in the real observation.
Errors from the provider are logged at debug level and swallowed —
text generation still runs (in text-only mode) so a flaky frame
source doesn't kill the REPL.
"""
if provider is None:
return None
try:
return provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s — falling back to text-only", exc)
return None
def _generate_with_policy(
policy: Any,
messages: list[dict[str, Any]],
*,
observation: dict | None = None,
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
When ``observation`` carries ``observation.images.*`` and
``observation.state``, those are merged into the batch so
``select_message`` runs the same VLM prefix the policy was trained
on. Without an observation the runtime falls back to a text-only
prompt — the text head still runs, but generations may drift from
the training distribution.
"""
if not hasattr(policy, "select_message"):
return ""
text_batch = _build_text_batch(policy, messages)
# ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS.
# The minimal text-only batch we build doesn't have images / state,
# so we either run a text-only forward (handled by SmolVLA2 when
# supported) or skip and return empty. v1 returns empty when the
# policy can't handle it; the runtime logs and continues.
try:
# Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK
# keys ``select_message`` uses internally.
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
batch = {
batch: dict[str, Any] = {
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
}
if observation:
for k, v in observation.items():
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
batch[k] = v
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
except Exception as exc: # noqa: BLE001
logger.debug("select_message fell back: %s", exc)
@@ -0,0 +1,119 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rich-based REPL layout for the SmolVLA2 runtime.
Two-zone terminal layout:
[chat scrollback — user messages / robot responses, scrolls naturally]
┌── State ──────────────────────────────────────────┐
│ task please clean up the kitchen │
│ subtask grasp the handle of the sponge │
│ plan 1. grasp sponge 2. wipe 3. tidy │
│ memory sponge picked up; counter still dirty │
└───────────────────────────────────────────────────┘
> _
The state panel re-renders on every state change. Chat lines are
``console.print``'d above the live region so they accumulate naturally
in scrollback. Implemented with :class:`rich.live.Live` plus
:func:`rich.console.Console.input` for the prompt — when an input is
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
panel for cursor position.
"""
from __future__ import annotations
from typing import Any
try: # rich is optional; only required for the interactive REPL.
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
_HAS_RICH = True
except ImportError: # pragma: no cover
_HAS_RICH = False
Console = Any # type: ignore[assignment]
Panel = Any # type: ignore[assignment]
Table = Any # type: ignore[assignment]
Text = Any # type: ignore[assignment]
_STATE_KEYS = (
("task", "task"),
("current_subtask", "subtask"),
("current_plan", "plan"),
("current_memory", "memory"),
)
def make_state_panel(state: dict[str, Any]) -> Any:
"""Render the persistent state panel for the live region.
Returns a :class:`rich.panel.Panel`. Caller passes it to
``Live.update(panel)`` whenever the state changes.
"""
if not _HAS_RICH:
raise RuntimeError(
"rich is required for the interactive REPL. "
"`pip install rich` (it's a transitive dep of lerobot)."
)
table = Table.grid(padding=(0, 2), expand=True)
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
table.add_column(justify="left")
for key, label in _STATE_KEYS:
value = state.get(key)
if value is None:
rendered = Text("(not set)", style="dim italic")
else:
rendered = Text(str(value), style="bold")
table.add_row(label, rendered)
queue = state.get("action_queue")
queue_len = len(queue) if hasattr(queue, "__len__") else 0
pending = state.get("tool_calls_pending") or []
footer = Text.assemble(
("queued actions: ", "dim"),
(str(queue_len), "bold cyan"),
(" pending tool calls: ", "dim"),
(str(len(pending)), "bold magenta"),
)
table.add_row("", footer)
return Panel(table, title="[bold]SmolVLA2 state[/]", border_style="cyan")
def print_user_line(console: Any, line: str) -> None:
"""Append a user-typed line to the chat scrollback."""
if not _HAS_RICH:
print(f"you: {line}", flush=True)
return
console.print(f"[bold cyan]you:[/] {line}")
def print_robot_lines(console: Any, lines: list[str]) -> None:
"""Append robot/runtime log lines to the chat scrollback."""
if not _HAS_RICH:
for line in lines:
print(f"robot: {line.lstrip()}", flush=True)
return
for line in lines:
# The runtime uses leading whitespace + "label: text"; render
# the label in green and the value in default for readability.
stripped = line.lstrip()
if ":" in stripped:
label, _, value = stripped.partition(":")
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
else:
console.print(f"[bold green]robot:[/] {stripped}")
+300 -35
View File
@@ -23,11 +23,21 @@ speech) as they happen.
Examples
--------
Dry run on a checkpoint, no robot connected — useful for sanity-
Dry run on a Hub checkpoint, no robot connected — useful for sanity-
checking text generation::
uv run lerobot-smolvla2-runtime \\
--policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
--no_robot \\
--task="please clean the kitchen"
Same, but feed real frames from an annotated dataset so plan / subtask
/ memory / VQA generation runs against actual video + state::
uv run lerobot-smolvla2-runtime \\
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
--dataset.repo_id=pepijn223/super_poulain_annotated \\
--dataset.episode=0 \\
--no_robot \\
--task="please clean the kitchen"
@@ -38,6 +48,9 @@ With a real robot::
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
--tts.voice=alba
``--policy.path`` accepts either a local directory or a Hugging Face
Hub repo id. ``--dataset.repo_id`` likewise.
Tool dispatch (TTS via ``SayTool``) is enabled by default when
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
"""
@@ -47,8 +60,7 @@ from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import Any, Callable
logger = logging.getLogger("lerobot.smolvla2.runtime")
@@ -61,9 +73,51 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p.add_argument(
"--policy.path",
dest="policy_path",
type=Path,
type=str,
required=True,
help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
help=(
"Local directory or Hugging Face Hub repo id pointing at a "
"trained SmolVLA2 ``pretrained_model``."
),
)
p.add_argument(
"--dataset.repo_id",
dest="dataset_repo_id",
type=str,
default=None,
help=(
"Optional dataset (local path or Hub repo id) used to drive "
"observations during dry-run inference. When set, the runtime "
"reads camera frames + state from the chosen episode and feeds "
"them into all forward passes — so plan / subtask / memory / "
"VQA generation see the same visual context the policy was "
"trained on."
),
)
p.add_argument(
"--dataset.episode",
dest="dataset_episode",
type=int,
default=0,
help="Episode index to walk through (default: 0).",
)
p.add_argument(
"--dataset.start_frame",
dest="dataset_start_frame",
type=int,
default=0,
help="Frame index within the episode to start from (default: 0).",
)
p.add_argument(
"--dataset.advance_per_tick",
dest="dataset_advance_per_tick",
type=int,
default=1,
help=(
"How many dataset frames to advance per runtime tick. The "
"default of 1 means the runtime walks the episode forward "
"frame by frame; set to 0 to freeze on ``start_frame``."
),
)
p.add_argument(
"--task",
@@ -111,16 +165,136 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
return p.parse_args(argv)
def _load_policy(path: Path): # noqa: ANN202
"""Load a SmolVLA2 checkpoint from ``path``."""
from lerobot.policies.factory import make_policy_from_path # noqa: PLC0415
def _load_policy_and_preprocessor(
policy_path: str,
dataset_repo_id: str | None,
) -> tuple[Any, Any, Any]:
"""Load a SmolVLA2 checkpoint (local path or Hub repo id).
When ``dataset_repo_id`` is provided, the dataset's metadata is used
to derive policy features (matching the standard
``make_policy(cfg, ds_meta=...)`` flow used by ``lerobot-train`` and
``lerobot-record``). When it isn't, we fall back to instantiating
the policy directly via ``from_pretrained`` — this skips the
feature-derivation path that ``make_policy`` insists on, but also
means we can't load the saved preprocessor pipeline (which depends
on ``input_features`` / ``output_features``). For inference-only
dry-runs this is fine; the policy still loads.
Returns ``(policy, preprocessor, ds_meta)`` where ``preprocessor``
and ``ds_meta`` may be ``None`` if no dataset was provided.
"""
from lerobot.configs import PreTrainedConfig # noqa: PLC0415
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: PLC0415
cfg = PreTrainedConfig.from_pretrained(policy_path)
cfg.pretrained_path = policy_path
ds_meta = None
preprocessor = None
if dataset_repo_id is not None:
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata # noqa: PLC0415
ds_meta = LeRobotDatasetMetadata(dataset_repo_id)
policy = make_policy(cfg, ds_meta=ds_meta)
preprocessor, _ = make_pre_post_processors(
cfg,
pretrained_path=policy_path,
dataset_stats=ds_meta.stats,
)
else:
# No dataset: instantiate the policy class directly so we don't
# need ds_meta. This bypasses ``make_policy``'s feature-shape
# derivation, which is fine for a pretrained checkpoint where
# the saved config already carries those shapes.
from lerobot.policies.factory import get_policy_class # noqa: PLC0415
policy_cls = get_policy_class(cfg.type)
policy = policy_cls.from_pretrained(policy_path, config=cfg)
policy.to(cfg.device)
policy = make_policy_from_path(str(path))
policy.eval()
return policy
return policy, preprocessor, ds_meta
def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
def _build_observation_provider(
*,
dataset_repo_id: str,
episode: int,
start_frame: int,
advance_per_tick: int,
preprocessor: Any,
device: str,
) -> Callable[[], dict | None]:
"""Build a closure that feeds dataset frames into the runtime.
Each call returns a preprocessed observation batch (images +
state, batched, on the policy's device, normalized) suitable for
``policy.select_action`` and ``policy.select_message``. The
closure walks the chosen episode forward by ``advance_per_tick``
frames per call, looping back to the episode start when it falls
off the end.
The dataset's ``language_persistent`` / ``language_events``
columns are stripped before the sample reaches the preprocessor,
so ``RenderMessagesStep`` and ``SmolVLA2ChatTokenizerStep`` are
no-ops; the runtime supplies its own messages from current state.
"""
import torch # noqa: PLC0415
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
ds = LeRobotDataset(dataset_repo_id, episodes=[episode])
if len(ds) == 0:
raise ValueError(
f"Dataset {dataset_repo_id!r} episode {episode} is empty."
)
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
def _provider() -> dict | None:
idx = state["cursor"]
if advance_per_tick > 0:
state["cursor"] = (idx + advance_per_tick) % len(ds)
sample = ds[idx]
# Strip the language columns so the preprocessor's render step
# is a no-op — the runtime drives messages itself.
for k in ("language_persistent", "language_events"):
sample.pop(k, None)
if preprocessor is not None:
sample = preprocessor(sample)
# Keep only observation keys; the runtime's text path will
# merge these with its own lang_tokens / lang_masks.
observation = {
k: v
for k, v in sample.items()
if isinstance(k, str) and k.startswith("observation.")
}
# Defensive: if something further upstream forgot the batch
# dim, add it now so downstream Tensor ops don't crash.
for k, v in list(observation.items()):
if isinstance(v, torch.Tensor) and v.ndim > 0 and v.shape[0] != 1:
# ``add_batch_dim`` already ran inside the preprocessor;
# an unbatched tensor at this point means a step
# somewhere produced an unbatched output. Best-effort
# fix.
if v.shape[0] != 1 and v.ndim < 4 and "image" not in k:
observation[k] = v.unsqueeze(0)
# Move to device (the preprocessor's DeviceProcessorStep should
# already have done this when ``preprocessor is not None``;
# this is a belt-and-braces no-op in the common case).
for k, v in list(observation.items()):
if isinstance(v, torch.Tensor):
observation[k] = v.to(device)
return observation
return _provider
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
"""Instantiate the tools declared on this dataset/policy."""
if no_tts:
return {}
@@ -140,20 +314,32 @@ def main(argv: list[str] | None = None) -> int:
format="%(asctime)s %(levelname)s %(message)s",
)
if not args.policy_path.exists():
print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
return 1
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
policy = _load_policy(args.policy_path)
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
args.policy_path, args.dataset_repo_id
)
tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
observation_provider: Callable[[], dict | None] | None = None
if args.dataset_repo_id is not None:
print(
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
f"episode={args.dataset_episode} "
f"start_frame={args.dataset_start_frame}",
flush=True,
)
observation_provider = _build_observation_provider(
dataset_repo_id=args.dataset_repo_id,
episode=args.dataset_episode,
start_frame=args.dataset_start_frame,
advance_per_tick=args.dataset_advance_per_tick,
preprocessor=preprocessor,
device=str(getattr(policy.config, "device", "cpu")),
)
tools = _build_tools(args.no_tts, args.tts_voice)
if tools:
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
# Robot wiring is left as a follow-up — for v1 we run language-only
# / dry-run so REPL development doesn't require a connected robot.
observation_provider = None
robot_executor = None
if not args.no_robot:
print(
@@ -162,33 +348,112 @@ def main(argv: list[str] | None = None) -> int:
flush=True,
)
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
SmolVLA2Runtime,
StdinReader,
)
from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415
runtime = SmolVLA2Runtime(
policy=policy,
tools=tools,
observation_provider=observation_provider,
robot_executor=robot_executor,
event_collector=StdinReader().poll,
# No background event collector — the REPL drives ticks
# synchronously after each user input. The runtime's own
# ``run()`` loop is bypassed here in favour of ``step_once()``
# so the input prompt and the live state panel co-exist
# cleanly.
event_collector=None,
chunk_hz=args.chunk_hz,
ctrl_hz=args.ctrl_hz,
high_level_hz=args.high_level_hz,
)
if args.task:
runtime.set_task(args.task)
print(
"[smolvla2] runtime ready. Type a task to begin, then any line for "
"interjections, questions ending in '?' for VQA, or 'stop' to exit.",
flush=True,
)
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
"""Two-zone TUI: chat scrollback above a persistent state panel.
Uses :class:`rich.live.Live` to keep the state panel rendered
below the chat history. ``console.input`` for the prompt — Live
auto-suspends repaint while the user is typing.
"""
try:
runtime.run(max_ticks=args.max_ticks)
except KeyboardInterrupt:
runtime.stop()
print("\n[smolvla2] interrupted by user", flush=True)
from rich.console import Console # noqa: PLC0415
from rich.live import Live # noqa: PLC0415
except ImportError:
print(
"[smolvla2] rich is required for the interactive REPL. "
"`pip install rich` and re-run.",
file=sys.stderr,
)
return 2
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
make_state_panel,
print_robot_lines,
print_user_line,
)
console = Console()
console.print(
"[bold]SmolVLA2[/] ready. "
"Type a task to begin, then any line for an interjection, "
"a line ending in '?' for VQA, or 'stop' to exit.",
)
if initial_task is None:
console.print("[dim]No --task provided; first stdin line will be used.[/]")
panel = make_state_panel(runtime.state)
ticks_done = 0
with Live(
panel,
console=console,
refresh_per_second=4,
transient=False,
screen=False,
) as live:
try:
while True:
try:
line = console.input("[bold cyan]>[/] ").strip()
except EOFError:
break
if not line:
continue
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
break
print_user_line(live.console, line)
# Inject the user input as the right kind of event,
# then run a single pipeline tick to consume it.
if not runtime.state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
runtime.set_task(task)
elif lower.endswith("?"):
runtime.state["recent_vqa_query"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_vqa_query"
)
else:
runtime.state["recent_interjection"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_interjection"
)
logs = runtime.step_once()
if logs:
print_robot_lines(live.console, logs)
live.update(make_state_panel(runtime.state))
ticks_done += 1
if max_ticks is not None and ticks_done >= max_ticks:
break
except KeyboardInterrupt:
console.print("\n[smolvla2] interrupted", style="dim")
print("[smolvla2] runtime stopped", flush=True)
return 0