From 3fe686ce9f5e71e998782cdccd1f34f21b9ed38f Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 5 May 2026 11:08:53 +0200 Subject: [PATCH] feat(smolvla2): runtime accepts Hub IDs + dataset-driven dry-run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The runtime CLI's loader was broken — it imported a `make_policy_from_path` that doesn't exist in `lerobot.policies.factory` — and the high-level text steps generated plan / subtask / memory / VQA from a text-only batch with no images or state, so dry-runs drifted from the training distribution. Switch to the standard `PreTrainedConfig.from_pretrained` + `make_policy(cfg, ds_meta=...)` flow so `--policy.path` accepts both local directories and Hub repo ids, and add a `--dataset.repo_id` path that walks a chosen episode and feeds preprocessed observations into every forward pass — including the four high-level steps (`HighLevelSubtaskFwd`, `MemoryUpdateFwd`, `UserInterjectionFwd`, `AskVQAFwd`). Frames are routed through the saved preprocessor pipeline with `language_persistent` / `language_events` stripped so the recipe-render step stays a no-op (the runtime supplies its own messages from current state). Also wires the rich-based two-zone REPL layout (`ui.py`) that the script was already importing. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/inference/__init__.py | 5 + .../policies/smolvla2/inference/runtime.py | 49 ++- .../policies/smolvla2/inference/steps.py | 69 +++- src/lerobot/policies/smolvla2/inference/ui.py | 119 +++++++ .../scripts/lerobot_smolvla2_runtime.py | 335 ++++++++++++++++-- 5 files changed, 520 insertions(+), 57 deletions(-) create mode 100644 src/lerobot/policies/smolvla2/inference/ui.py diff --git a/src/lerobot/policies/smolvla2/inference/__init__.py b/src/lerobot/policies/smolvla2/inference/__init__.py index c65c301d5..30f77635a 100644 --- a/src/lerobot/policies/smolvla2/inference/__init__.py +++ b/src/lerobot/policies/smolvla2/inference/__init__.py @@ -40,6 +40,7 @@ from .steps import ( UserInterjectionFwd, ) from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger +from .ui import make_state_panel, print_robot_lines, print_user_line __all__ = [ # runtime @@ -65,4 +66,8 @@ __all__ = [ "UserInterjectionFwd", "AskVQAFwd", "DispatchToolCalls", + # UI + "make_state_panel", + "print_robot_lines", + "print_user_line", ] diff --git a/src/lerobot/policies/smolvla2/inference/runtime.py b/src/lerobot/policies/smolvla2/inference/runtime.py index 4b78f030f..f888b8bc3 100644 --- a/src/lerobot/policies/smolvla2/inference/runtime.py +++ b/src/lerobot/policies/smolvla2/inference/runtime.py @@ -82,10 +82,20 @@ class SmolVLA2Runtime: HighLevelSubtaskFwd( trigger=HzTrigger(self.high_level_hz), policy=self.policy, + observation_provider=self.observation_provider, + ), + MemoryUpdateFwd( + policy=self.policy, + observation_provider=self.observation_provider, + ), + UserInterjectionFwd( + policy=self.policy, + observation_provider=self.observation_provider, + ), + AskVQAFwd( + policy=self.policy, + observation_provider=self.observation_provider, ), - MemoryUpdateFwd(policy=self.policy), - UserInterjectionFwd(policy=self.policy), - AskVQAFwd(policy=self.policy), DispatchToolCalls(tools=self.tools), ] self.state = initial_runtime_state() @@ -127,6 +137,39 @@ class SmolVLA2Runtime: self._on_shutdown() + # ------------------------------------------------------------------ + # REPL helper: drive one full pipeline pass and return its logs + # ------------------------------------------------------------------ + + def step_once(self) -> list[str]: + """Run one tick of the pipeline and return the log lines. + + Used by the interactive REPL: instead of a background thread, + the CLI drives ticks synchronously after each user input. Logs + are returned (not printed) so the caller can route them into + the rich-Live chat scrollback. + """ + from .triggers import Tick # noqa: PLC0415 + + # Synthesize a tick. We don't need the real wall-clock pacing + # here — the REPL drives the runtime, not vice versa — but + # ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we + # bump it generously so every Hz-triggered step considers + # itself due. + import time as _time # noqa: PLC0415 + + prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0 + self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic()) + self.state["log_lines"] = [] + # ``events_this_tick`` is set up by the caller before + # ``step_once`` (the REPL pushes user-driven events first). + self.state.setdefault("events_this_tick", []) + + for step in self.pipeline: + self.state = step(self.state) + + return list(self.state.get("log_lines") or []) + # ------------------------------------------------------------------ # I/O # ------------------------------------------------------------------ diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index ca36b1854..0841359b0 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -171,13 +171,19 @@ class HighLevelSubtaskFwd(InferenceStep): """At ~1 Hz, ask the policy for the next subtask.""" policy: Any = None + observation_provider: Any = None + """Same shape as ``LowLevelForward.observation_provider``. When + set, the resulting observation is merged into ``select_message``'s + batch so text generation runs against real video + state.""" + trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0)) def run(self, state: dict[str, Any]) -> dict[str, Any] | None: if self.policy is None or not state.get("task"): return None ctx = _control_context_messages(state) - msg = _generate_with_policy(self.policy, ctx) + observation = _maybe_observation(self.observation_provider) + msg = _generate_with_policy(self.policy, ctx, observation=observation) if msg: changed = set_if_changed(state, "current_subtask", msg, label="subtask") if changed: @@ -191,6 +197,7 @@ class MemoryUpdateFwd(InferenceStep): """On subtask boundary, refresh the compressed memory.""" policy: Any = None + observation_provider: Any = None trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change")) def run(self, state: dict[str, Any]) -> dict[str, Any] | None: @@ -198,7 +205,8 @@ class MemoryUpdateFwd(InferenceStep): if self.policy is None: return None ctx = _control_context_messages(state, include_completed=True) - new_memory = _generate_with_policy(self.policy, ctx) + observation = _maybe_observation(self.observation_provider) + new_memory = _generate_with_policy(self.policy, ctx, observation=observation) if new_memory: set_if_changed(state, "current_memory", new_memory, label="memory") return None @@ -209,6 +217,7 @@ class UserInterjectionFwd(InferenceStep): """On stdin interjection, refresh the plan + emit a paired ``say``.""" policy: Any = None + observation_provider: Any = None trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection")) def run(self, state: dict[str, Any]) -> dict[str, Any] | None: @@ -218,7 +227,8 @@ class UserInterjectionFwd(InferenceStep): state, extra_user=state.get("recent_interjection"), ) - out = _generate_with_policy(self.policy, ctx) + observation = _maybe_observation(self.observation_provider) + out = _generate_with_policy(self.policy, ctx, observation=observation) if not out: return None # Heuristic split: model is trained to emit one assistant turn @@ -247,6 +257,7 @@ class AskVQAFwd(InferenceStep): """On stdin question, answer a frame-grounded VQA.""" policy: Any = None + observation_provider: Any = None trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query")) def run(self, state: dict[str, Any]) -> dict[str, Any] | None: @@ -256,7 +267,8 @@ class AskVQAFwd(InferenceStep): if not question: return None ctx = _control_context_messages(state, extra_user=question) - answer = _generate_with_policy(self.policy, ctx) + observation = _maybe_observation(self.observation_provider) + answer = _generate_with_policy(self.policy, ctx, observation=observation) if answer: push_log(state, f" vqa: {answer}") state["recent_vqa_query"] = None @@ -326,35 +338,54 @@ def _control_context_messages( return msgs -def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str: - """Drive ``policy.select_message`` with a minimal text-only batch. +def _maybe_observation(provider: Any) -> dict | None: + """Pull one observation from ``provider`` if it's set, else ``None``. - Best-effort: the runtime today doesn't construct a full - observation batch with images / state for text generation; the - text-head was trained over images + lang + state, so generations - here may differ in distribution from training. This is acceptable - for a v1 REPL; a follow-up will plug in the real observation. + Errors from the provider are logged at debug level and swallowed — + text generation still runs (in text-only mode) so a flaky frame + source doesn't kill the REPL. + """ + if provider is None: + return None + try: + return provider() + except Exception as exc: # noqa: BLE001 + logger.debug("observation_provider raised %s — falling back to text-only", exc) + return None + + +def _generate_with_policy( + policy: Any, + messages: list[dict[str, Any]], + *, + observation: dict | None = None, +) -> str: + """Drive ``policy.select_message`` with a chat batch (and optional obs). + + When ``observation`` carries ``observation.images.*`` and + ``observation.state``, those are merged into the batch so + ``select_message`` runs the same VLM prefix the policy was trained + on. Without an observation the runtime falls back to a text-only + prompt — the text head still runs, but generations may drift from + the training distribution. """ if not hasattr(policy, "select_message"): return "" text_batch = _build_text_batch(policy, messages) - # ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS. - # The minimal text-only batch we build doesn't have images / state, - # so we either run a text-only forward (handled by SmolVLA2 when - # supported) or skip and return empty. v1 returns empty when the - # policy can't handle it; the runtime logs and continues. try: - # Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK - # keys ``select_message`` uses internally. from lerobot.utils.constants import ( # noqa: PLC0415 OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, ) - batch = { + batch: dict[str, Any] = { OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"], OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"], } + if observation: + for k, v in observation.items(): + if isinstance(k, str) and k.startswith("observation.") and k not in batch: + batch[k] = v return policy.select_message(batch, tokenizer=text_batch["tokenizer"]) except Exception as exc: # noqa: BLE001 logger.debug("select_message fell back: %s", exc) diff --git a/src/lerobot/policies/smolvla2/inference/ui.py b/src/lerobot/policies/smolvla2/inference/ui.py new file mode 100644 index 000000000..692333f21 --- /dev/null +++ b/src/lerobot/policies/smolvla2/inference/ui.py @@ -0,0 +1,119 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rich-based REPL layout for the SmolVLA2 runtime. + +Two-zone terminal layout: + + [chat scrollback — user messages / robot responses, scrolls naturally] + + ┌── State ──────────────────────────────────────────┐ + │ task please clean up the kitchen │ + │ subtask grasp the handle of the sponge │ + │ plan 1. grasp sponge 2. wipe 3. tidy │ + │ memory sponge picked up; counter still dirty │ + └───────────────────────────────────────────────────┘ + > _ + +The state panel re-renders on every state change. Chat lines are +``console.print``'d above the live region so they accumulate naturally +in scrollback. Implemented with :class:`rich.live.Live` plus +:func:`rich.console.Console.input` for the prompt — when an input is +pending, ``rich.Live`` auto-suspends so the input doesn't fight the +panel for cursor position. +""" + +from __future__ import annotations + +from typing import Any + +try: # rich is optional; only required for the interactive REPL. + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + + _HAS_RICH = True +except ImportError: # pragma: no cover + _HAS_RICH = False + Console = Any # type: ignore[assignment] + Panel = Any # type: ignore[assignment] + Table = Any # type: ignore[assignment] + Text = Any # type: ignore[assignment] + + +_STATE_KEYS = ( + ("task", "task"), + ("current_subtask", "subtask"), + ("current_plan", "plan"), + ("current_memory", "memory"), +) + + +def make_state_panel(state: dict[str, Any]) -> Any: + """Render the persistent state panel for the live region. + + Returns a :class:`rich.panel.Panel`. Caller passes it to + ``Live.update(panel)`` whenever the state changes. + """ + if not _HAS_RICH: + raise RuntimeError( + "rich is required for the interactive REPL. " + "`pip install rich` (it's a transitive dep of lerobot)." + ) + table = Table.grid(padding=(0, 2), expand=True) + table.add_column(justify="right", style="dim", no_wrap=True, width=10) + table.add_column(justify="left") + for key, label in _STATE_KEYS: + value = state.get(key) + if value is None: + rendered = Text("(not set)", style="dim italic") + else: + rendered = Text(str(value), style="bold") + table.add_row(label, rendered) + queue = state.get("action_queue") + queue_len = len(queue) if hasattr(queue, "__len__") else 0 + pending = state.get("tool_calls_pending") or [] + footer = Text.assemble( + ("queued actions: ", "dim"), + (str(queue_len), "bold cyan"), + (" pending tool calls: ", "dim"), + (str(len(pending)), "bold magenta"), + ) + table.add_row("", footer) + return Panel(table, title="[bold]SmolVLA2 state[/]", border_style="cyan") + + +def print_user_line(console: Any, line: str) -> None: + """Append a user-typed line to the chat scrollback.""" + if not _HAS_RICH: + print(f"you: {line}", flush=True) + return + console.print(f"[bold cyan]you:[/] {line}") + + +def print_robot_lines(console: Any, lines: list[str]) -> None: + """Append robot/runtime log lines to the chat scrollback.""" + if not _HAS_RICH: + for line in lines: + print(f"robot: {line.lstrip()}", flush=True) + return + for line in lines: + # The runtime uses leading whitespace + "label: text"; render + # the label in green and the value in default for readability. + stripped = line.lstrip() + if ":" in stripped: + label, _, value = stripped.partition(":") + console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}") + else: + console.print(f"[bold green]robot:[/] {stripped}") diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index f01933146..941da8a49 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -23,11 +23,21 @@ speech) as they happen. Examples -------- -Dry run on a checkpoint, no robot connected — useful for sanity- +Dry run on a Hub checkpoint, no robot connected — useful for sanity- checking text generation:: uv run lerobot-smolvla2-runtime \\ - --policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\ + --policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\ + --no_robot \\ + --task="please clean the kitchen" + +Same, but feed real frames from an annotated dataset so plan / subtask +/ memory / VQA generation runs against actual video + state:: + + uv run lerobot-smolvla2-runtime \\ + --policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\ + --dataset.repo_id=pepijn223/super_poulain_annotated \\ + --dataset.episode=0 \\ --no_robot \\ --task="please clean the kitchen" @@ -38,6 +48,9 @@ With a real robot:: --robot.type=so101 --robot.port=/dev/tty.usbmodem... \\ --tts.voice=alba +``--policy.path`` accepts either a local directory or a Hugging Face +Hub repo id. ``--dataset.repo_id`` likewise. + Tool dispatch (TTS via ``SayTool``) is enabled by default when ``pocket-tts`` is installed; pass ``--no_tts`` to disable. """ @@ -47,8 +60,7 @@ from __future__ import annotations import argparse import logging import sys -from pathlib import Path -from typing import Any +from typing import Any, Callable logger = logging.getLogger("lerobot.smolvla2.runtime") @@ -61,9 +73,51 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: p.add_argument( "--policy.path", dest="policy_path", - type=Path, + type=str, required=True, - help="Path to a trained SmolVLA2 ``pretrained_model`` directory.", + help=( + "Local directory or Hugging Face Hub repo id pointing at a " + "trained SmolVLA2 ``pretrained_model``." + ), + ) + p.add_argument( + "--dataset.repo_id", + dest="dataset_repo_id", + type=str, + default=None, + help=( + "Optional dataset (local path or Hub repo id) used to drive " + "observations during dry-run inference. When set, the runtime " + "reads camera frames + state from the chosen episode and feeds " + "them into all forward passes — so plan / subtask / memory / " + "VQA generation see the same visual context the policy was " + "trained on." + ), + ) + p.add_argument( + "--dataset.episode", + dest="dataset_episode", + type=int, + default=0, + help="Episode index to walk through (default: 0).", + ) + p.add_argument( + "--dataset.start_frame", + dest="dataset_start_frame", + type=int, + default=0, + help="Frame index within the episode to start from (default: 0).", + ) + p.add_argument( + "--dataset.advance_per_tick", + dest="dataset_advance_per_tick", + type=int, + default=1, + help=( + "How many dataset frames to advance per runtime tick. The " + "default of 1 means the runtime walks the episode forward " + "frame by frame; set to 0 to freeze on ``start_frame``." + ), ) p.add_argument( "--task", @@ -111,16 +165,136 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: return p.parse_args(argv) -def _load_policy(path: Path): # noqa: ANN202 - """Load a SmolVLA2 checkpoint from ``path``.""" - from lerobot.policies.factory import make_policy_from_path # noqa: PLC0415 +def _load_policy_and_preprocessor( + policy_path: str, + dataset_repo_id: str | None, +) -> tuple[Any, Any, Any]: + """Load a SmolVLA2 checkpoint (local path or Hub repo id). + + When ``dataset_repo_id`` is provided, the dataset's metadata is used + to derive policy features (matching the standard + ``make_policy(cfg, ds_meta=...)`` flow used by ``lerobot-train`` and + ``lerobot-record``). When it isn't, we fall back to instantiating + the policy directly via ``from_pretrained`` — this skips the + feature-derivation path that ``make_policy`` insists on, but also + means we can't load the saved preprocessor pipeline (which depends + on ``input_features`` / ``output_features``). For inference-only + dry-runs this is fine; the policy still loads. + + Returns ``(policy, preprocessor, ds_meta)`` where ``preprocessor`` + and ``ds_meta`` may be ``None`` if no dataset was provided. + """ + from lerobot.configs import PreTrainedConfig # noqa: PLC0415 + from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: PLC0415 + + cfg = PreTrainedConfig.from_pretrained(policy_path) + cfg.pretrained_path = policy_path + + ds_meta = None + preprocessor = None + if dataset_repo_id is not None: + from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata # noqa: PLC0415 + + ds_meta = LeRobotDatasetMetadata(dataset_repo_id) + policy = make_policy(cfg, ds_meta=ds_meta) + preprocessor, _ = make_pre_post_processors( + cfg, + pretrained_path=policy_path, + dataset_stats=ds_meta.stats, + ) + else: + # No dataset: instantiate the policy class directly so we don't + # need ds_meta. This bypasses ``make_policy``'s feature-shape + # derivation, which is fine for a pretrained checkpoint where + # the saved config already carries those shapes. + from lerobot.policies.factory import get_policy_class # noqa: PLC0415 + + policy_cls = get_policy_class(cfg.type) + policy = policy_cls.from_pretrained(policy_path, config=cfg) + policy.to(cfg.device) - policy = make_policy_from_path(str(path)) policy.eval() - return policy + return policy, preprocessor, ds_meta -def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]: +def _build_observation_provider( + *, + dataset_repo_id: str, + episode: int, + start_frame: int, + advance_per_tick: int, + preprocessor: Any, + device: str, +) -> Callable[[], dict | None]: + """Build a closure that feeds dataset frames into the runtime. + + Each call returns a preprocessed observation batch (images + + state, batched, on the policy's device, normalized) suitable for + ``policy.select_action`` and ``policy.select_message``. The + closure walks the chosen episode forward by ``advance_per_tick`` + frames per call, looping back to the episode start when it falls + off the end. + + The dataset's ``language_persistent`` / ``language_events`` + columns are stripped before the sample reaches the preprocessor, + so ``RenderMessagesStep`` and ``SmolVLA2ChatTokenizerStep`` are + no-ops; the runtime supplies its own messages from current state. + """ + import torch # noqa: PLC0415 + + from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415 + + ds = LeRobotDataset(dataset_repo_id, episodes=[episode]) + if len(ds) == 0: + raise ValueError( + f"Dataset {dataset_repo_id!r} episode {episode} is empty." + ) + + state = {"cursor": max(0, min(start_frame, len(ds) - 1))} + + def _provider() -> dict | None: + idx = state["cursor"] + if advance_per_tick > 0: + state["cursor"] = (idx + advance_per_tick) % len(ds) + + sample = ds[idx] + # Strip the language columns so the preprocessor's render step + # is a no-op — the runtime drives messages itself. + for k in ("language_persistent", "language_events"): + sample.pop(k, None) + + if preprocessor is not None: + sample = preprocessor(sample) + + # Keep only observation keys; the runtime's text path will + # merge these with its own lang_tokens / lang_masks. + observation = { + k: v + for k, v in sample.items() + if isinstance(k, str) and k.startswith("observation.") + } + # Defensive: if something further upstream forgot the batch + # dim, add it now so downstream Tensor ops don't crash. + for k, v in list(observation.items()): + if isinstance(v, torch.Tensor) and v.ndim > 0 and v.shape[0] != 1: + # ``add_batch_dim`` already ran inside the preprocessor; + # an unbatched tensor at this point means a step + # somewhere produced an unbatched output. Best-effort + # fix. + if v.shape[0] != 1 and v.ndim < 4 and "image" not in k: + observation[k] = v.unsqueeze(0) + # Move to device (the preprocessor's DeviceProcessorStep should + # already have done this when ``preprocessor is not None``; + # this is a belt-and-braces no-op in the common case). + for k, v in list(observation.items()): + if isinstance(v, torch.Tensor): + observation[k] = v.to(device) + return observation + + return _provider + + +def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]: """Instantiate the tools declared on this dataset/policy.""" if no_tts: return {} @@ -140,20 +314,32 @@ def main(argv: list[str] | None = None) -> int: format="%(asctime)s %(levelname)s %(message)s", ) - if not args.policy_path.exists(): - print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr) - return 1 - print(f"[smolvla2] loading policy from {args.policy_path}", flush=True) - policy = _load_policy(args.policy_path) + policy, preprocessor, _ds_meta = _load_policy_and_preprocessor( + args.policy_path, args.dataset_repo_id + ) - tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice) + observation_provider: Callable[[], dict | None] | None = None + if args.dataset_repo_id is not None: + print( + f"[smolvla2] streaming observations from {args.dataset_repo_id} " + f"episode={args.dataset_episode} " + f"start_frame={args.dataset_start_frame}", + flush=True, + ) + observation_provider = _build_observation_provider( + dataset_repo_id=args.dataset_repo_id, + episode=args.dataset_episode, + start_frame=args.dataset_start_frame, + advance_per_tick=args.dataset_advance_per_tick, + preprocessor=preprocessor, + device=str(getattr(policy.config, "device", "cpu")), + ) + + tools = _build_tools(args.no_tts, args.tts_voice) if tools: print(f"[smolvla2] tools loaded: {list(tools)}", flush=True) - # Robot wiring is left as a follow-up — for v1 we run language-only - # / dry-run so REPL development doesn't require a connected robot. - observation_provider = None robot_executor = None if not args.no_robot: print( @@ -162,33 +348,112 @@ def main(argv: list[str] | None = None) -> int: flush=True, ) - from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415 - SmolVLA2Runtime, - StdinReader, - ) + from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415 runtime = SmolVLA2Runtime( policy=policy, tools=tools, observation_provider=observation_provider, robot_executor=robot_executor, - event_collector=StdinReader().poll, + # No background event collector — the REPL drives ticks + # synchronously after each user input. The runtime's own + # ``run()`` loop is bypassed here in favour of ``step_once()`` + # so the input prompt and the live state panel co-exist + # cleanly. + event_collector=None, chunk_hz=args.chunk_hz, ctrl_hz=args.ctrl_hz, high_level_hz=args.high_level_hz, ) if args.task: runtime.set_task(args.task) - print( - "[smolvla2] runtime ready. Type a task to begin, then any line for " - "interjections, questions ending in '?' for VQA, or 'stop' to exit.", - flush=True, - ) + + return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks) + + +def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int: + """Two-zone TUI: chat scrollback above a persistent state panel. + + Uses :class:`rich.live.Live` to keep the state panel rendered + below the chat history. ``console.input`` for the prompt — Live + auto-suspends repaint while the user is typing. + """ try: - runtime.run(max_ticks=args.max_ticks) - except KeyboardInterrupt: - runtime.stop() - print("\n[smolvla2] interrupted by user", flush=True) + from rich.console import Console # noqa: PLC0415 + from rich.live import Live # noqa: PLC0415 + except ImportError: + print( + "[smolvla2] rich is required for the interactive REPL. " + "`pip install rich` and re-run.", + file=sys.stderr, + ) + return 2 + + from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415 + make_state_panel, + print_robot_lines, + print_user_line, + ) + + console = Console() + console.print( + "[bold]SmolVLA2[/] ready. " + "Type a task to begin, then any line for an interjection, " + "a line ending in '?' for VQA, or 'stop' to exit.", + ) + if initial_task is None: + console.print("[dim]No --task provided; first stdin line will be used.[/]") + + panel = make_state_panel(runtime.state) + ticks_done = 0 + with Live( + panel, + console=console, + refresh_per_second=4, + transient=False, + screen=False, + ) as live: + try: + while True: + try: + line = console.input("[bold cyan]>[/] ").strip() + except EOFError: + break + if not line: + continue + lower = line.lower() + if lower in {"stop", "quit", "exit"}: + break + + print_user_line(live.console, line) + + # Inject the user input as the right kind of event, + # then run a single pipeline tick to consume it. + if not runtime.state.get("task"): + task = line[5:].strip() if lower.startswith("task:") else line + runtime.set_task(task) + elif lower.endswith("?"): + runtime.state["recent_vqa_query"] = line + runtime.state.setdefault("events_this_tick", []).append( + "user_vqa_query" + ) + else: + runtime.state["recent_interjection"] = line + runtime.state.setdefault("events_this_tick", []).append( + "user_interjection" + ) + + logs = runtime.step_once() + if logs: + print_robot_lines(live.console, logs) + live.update(make_state_panel(runtime.state)) + + ticks_done += 1 + if max_ticks is not None and ticks_done >= max_ticks: + break + except KeyboardInterrupt: + console.print("\n[smolvla2] interrupted", style="dim") + print("[smolvla2] runtime stopped", flush=True) return 0