From 26cb38a7d0eb215b3bcdb055e5004255c4cc165f Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Mon, 18 May 2026 11:20:57 +0200
Subject: [PATCH] feat(smolvla2): startup task picker, /vlm mode toggle, interactive VQA overlay
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Three additions to the SmolVLA2 interactive runtime:

1. Startup task picker — when no --task is given, the runtime lists the
   dataset's task strings as a numbered menu (plus a custom-task option)
   instead of silently waiting for the first stdin line.

2. Mode toggle — /action and /vlm slash commands flip a persistent run
   mode. /vlm pauses the whole action loop (HighLevelSubtaskFwd,
   LowLevelForward and DispatchAction gate on state["mode"]) and clears
   the action queue so the robot holds position; /action resumes it.
   The mode is shown in the state panel.

3. Interactive VQA — in /vlm mode a typed line is a VQA question. The
   new inference/vqa.py module asks which camera to ground on, runs the
   VLM on that single camera, and when the answer is a bbox/keypoint it
   draws the overlay, saves a PNG to ./vqa_overlays/ and auto-opens it.

Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/inference/repl.py | 18 + .../smolvla2/inference/runtime_state.py | 4 + .../policies/smolvla2/inference/steps.py | 14 + src/lerobot/policies/smolvla2/inference/ui.py | 10 +- .../policies/smolvla2/inference/vqa.py | 339 ++++++++++++++++++ .../scripts/lerobot_smolvla2_runtime.py | 175 ++++++++- .../smolvla/test_smolvla2_vqa_overlay.py | 179 +++++++++ 7 files changed, 734 insertions(+), 5 deletions(-) create mode 100644 src/lerobot/policies/smolvla2/inference/vqa.py create mode 100644 tests/policies/smolvla/test_smolvla2_vqa_overlay.py diff --git a/src/lerobot/policies/smolvla2/inference/repl.py b/src/lerobot/policies/smolvla2/inference/repl.py index 6afc0ef98..7ab84dbf2 100644 --- a/src/lerobot/policies/smolvla2/inference/repl.py +++ b/src/lerobot/policies/smolvla2/inference/repl.py @@ -16,11 +16,16 @@ Reads non-blocking stdin lines, classifies each one heuristically: "stop" / "quit" / "exit" → state["stop"] = True + "/action" / "/vlm" → set state["mode"] ends with "?" → user_vqa_query event starts with "task:" or first line → set runtime task anything else → user_interjection event Plugged into the runtime via ``event_collector=StdinReader().poll``. + +Note: the shipped CLI (``lerobot-smolvla2-runtime``) drives stdin +directly in its REPL / autonomous loops and does *not* wire this +collector; it's kept as the documented embedding hook and for tests. """ from __future__ import annotations @@ -70,6 +75,19 @@ class StdinReader: state["stop"] = True return + # Slash commands flip the run mode. ``/vlm`` pauses the action + # loop (the action steps gate on ``state["mode"]``); ``/action`` + # resumes it. + if lower in {"/action", "/act"}: + state["mode"] = "action" + return + if lower in {"/vlm", "/vqa"}: + state["mode"] = "vlm" + queue = state.get("action_queue") + if hasattr(queue, "clear"): + queue.clear() + return + # First non-control line sets the task if no task is active. 
if not state.get("task"): task = line[5:].strip() if lower.startswith("task:") else line diff --git a/src/lerobot/policies/smolvla2/inference/runtime_state.py b/src/lerobot/policies/smolvla2/inference/runtime_state.py index 978a2c83e..49f2f8874 100644 --- a/src/lerobot/policies/smolvla2/inference/runtime_state.py +++ b/src/lerobot/policies/smolvla2/inference/runtime_state.py @@ -33,6 +33,9 @@ Stable keys (read by multiple steps): events_this_tick list[str] triggers consumed this tick _tick Tick current tick (set by the loop) + mode str "action" (run the robot) | "vlm" (VQA only, + action loop paused) + log_lines list[str] human-readable status lines printed each tick """ @@ -54,6 +57,7 @@ def initial_runtime_state(task: str | None = None) -> dict[str, Any]: "tool_calls_pending": [], "events_this_tick": [], "log_lines": [], + "mode": "action", "stop": False, } diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index a36ae26b5..e37604924 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -91,6 +91,10 @@ class LowLevelForward(InferenceStep): def run(self, state: dict[str, Any]) -> dict[str, Any] | None: if self.policy is None or self.observation_provider is None: return None + # ``/vlm`` mode pauses the whole action loop so the robot holds + # position while the operator probes the VLM with VQA. + if state.get("mode", "action") != "action": + return None if not state.get("task"): return None @@ -197,6 +201,12 @@ class DispatchAction(InferenceStep): def run(self, state: dict[str, Any]) -> dict[str, Any] | None: import time as _time # noqa: PLC0415 + # ``/vlm`` mode pauses dispatch — the robot holds its last + # commanded position while the operator runs VQA. 
+ if state.get("mode", "action") != "action": + self._last_dispatch_t = None + return None + queue = state.get("action_queue") if not queue: # Reset wall-clock anchor when the queue is empty so the @@ -366,6 +376,10 @@ class HighLevelSubtaskFwd(InferenceStep): def run(self, state: dict[str, Any]) -> dict[str, Any] | None: if self.policy is None or not state.get("task"): return None + # ``/vlm`` mode pauses subtask generation along with the rest of + # the action loop. + if state.get("mode", "action") != "action": + return None # Gate to chunk boundaries: only generate a fresh subtask when # the action queue is empty (i.e. right before LowLevelForward # refreshes the chunk). ``select_message`` takes ~2 s on MPS, diff --git a/src/lerobot/policies/smolvla2/inference/ui.py b/src/lerobot/policies/smolvla2/inference/ui.py index 692333f21..567610bce 100644 --- a/src/lerobot/policies/smolvla2/inference/ui.py +++ b/src/lerobot/policies/smolvla2/inference/ui.py @@ -91,7 +91,15 @@ def make_state_panel(state: dict[str, Any]) -> Any: (str(len(pending)), "bold magenta"), ) table.add_row("", footer) - return Panel(table, title="[bold]SmolVLA2 state[/]", border_style="cyan") + run_mode = state.get("mode", "action") + mode_tag = ( + "[green]action[/]" if run_mode == "action" else "[yellow]vlm (paused)[/]" + ) + return Panel( + table, + title=f"[bold]SmolVLA2 state[/] · mode: {mode_tag}", + border_style="cyan", + ) def print_user_line(console: Any, line: str) -> None: diff --git a/src/lerobot/policies/smolvla2/inference/vqa.py b/src/lerobot/policies/smolvla2/inference/vqa.py new file mode 100644 index 000000000..3263b33bd --- /dev/null +++ b/src/lerobot/policies/smolvla2/inference/vqa.py @@ -0,0 +1,339 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interactive VQA for the SmolVLA2 runtime. + +In ``/vlm`` mode a typed line is treated as a VQA question. This module +runs the full interactive flow: + + 1. pull the current observation and list available cameras, + 2. ask the operator which camera to ground the question on, + 3. generate the answer with the VLM conditioned on that one camera, + 4. parse the JSON answer; if it carries a bounding box (``bbox``) or a + point (``keypoint``), draw the overlay on the camera frame, save a + PNG to ``./vqa_overlays/`` and auto-open it. + +VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES`` +(see ``lerobot.annotations.steerable_pipeline.validator``): + + * ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy", + "bbox": [x1, y1, x2, y2]}, ...]}`` + * ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}`` + * ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay. +""" + +from __future__ import annotations + +import json +import logging +import os +import subprocess +import sys +import time +import webbrowser +from pathlib import Path +from typing import Any + +from .runtime_state import push_log + +logger = logging.getLogger(__name__) + +_IMAGE_PREFIX = "observation.images." + +# Iteration order for shape matching — most specific keys first so an +# answer is classified deterministically. 
+_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial") + +_BBOX_COLOR = (255, 64, 64) +_POINT_COLOR = (64, 220, 64) + + +# --------------------------------------------------------------------------- +# Camera selection +# --------------------------------------------------------------------------- + + +def available_cameras(observation: dict | None) -> list[str]: + """Return the sorted ``observation.images.*`` keys present in ``observation``.""" + if not observation: + return [] + return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX)) + + +def camera_short_name(camera_key: str) -> str: + """Strip the ``observation.images.`` prefix for display.""" + return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key + + +def prompt_camera_choice( + cameras: list[str], + *, + input_fn: Any = input, + print_fn: Any = print, +) -> str | None: + """Ask the operator which camera to ground a VQA question on. + + Accepts either the menu number or the (short or full) camera name. + A single-camera setup auto-selects without prompting. Returns the + chosen ``observation.images.*`` key, or ``None`` if the operator + cancels / gives an invalid answer. 
+ """ + if not cameras: + return None + if len(cameras) == 1: + return cameras[0] + print_fn("Which camera should I look at?") + for i, cam in enumerate(cameras, 1): + print_fn(f" [{i}] {camera_short_name(cam)}") + try: + raw = str(input_fn("camera> ")).strip() + except (EOFError, KeyboardInterrupt): + return None + if not raw: + return cameras[0] + if raw.isdigit(): + idx = int(raw) - 1 + return cameras[idx] if 0 <= idx < len(cameras) else None + for cam in cameras: + if raw == cam or raw == camera_short_name(cam): + return cam + return None + + +# --------------------------------------------------------------------------- +# Answer parsing +# --------------------------------------------------------------------------- + + +def parse_vqa_answer(answer: str) -> dict | None: + """Parse a VQA answer string into ``{"kind", "payload"}``. + + ``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``, + ``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"`` + when the JSON doesn't match any known shape. Returns ``None`` when + the answer is not parseable JSON / not a JSON object. 
+ """ + if not answer or not answer.strip(): + return None + try: + payload = json.loads(answer) + except (ValueError, TypeError): + return None + if not isinstance(payload, dict): + return None + + try: + from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415 + VQA_ANSWER_SHAPES, + ) + + shapes = VQA_ANSWER_SHAPES + except ImportError: # pragma: no cover - annotation extra not installed + shapes = { + "bbox": {"detections"}, + "keypoint": {"label", "point_format", "point"}, + "count": {"label", "count"}, + "attribute": {"label", "attribute", "value"}, + "spatial": {"subject", "relation", "object"}, + } + + keys = set(payload) + for kind in _SHAPE_ORDER: + required = shapes.get(kind) + if required and required <= keys: + return {"kind": kind, "payload": payload} + return {"kind": "unknown", "payload": payload} + + +def answer_has_overlay(parsed: dict | None) -> bool: + """True iff ``parsed`` carries drawable spatial coordinates.""" + return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint") + + +# --------------------------------------------------------------------------- +# Overlay drawing +# --------------------------------------------------------------------------- + + +def observation_image_to_pil(image_tensor: Any) -> Any: + """Convert an ``observation.images.*`` tensor to a PIL RGB image. + + The runtime observation stores images as ``(1, C, H, W)`` (or + ``(C, H, W)``) float tensors in ``[0, 1]``. Reuses + ``image_array_to_pil_image`` which handles the CHW→HWC transpose and + the float→uint8 scaling. 
+ """ + from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415 + + arr = image_tensor + if hasattr(arr, "detach"): + arr = arr.detach().cpu() + if hasattr(arr, "numpy"): + arr = arr.numpy() + while arr.ndim > 3: # drop leading batch dim(s) + arr = arr[0] + return image_array_to_pil_image(arr).convert("RGB") + + +def draw_vqa_overlay(image: Any, parsed: dict) -> Any: + """Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``. + + Non-spatial answers (``count`` / ``attribute`` / ``spatial`` / + ``unknown``) are returned as an unmodified copy. + """ + from PIL import ImageDraw # noqa: PLC0415 + + img = image.convert("RGB").copy() + kind = parsed.get("kind") + payload = parsed.get("payload") or {} + draw = ImageDraw.Draw(img) + + if kind == "bbox": + for det in payload.get("detections") or []: + if not isinstance(det, dict): + continue + box = det.get("bbox") + if not (isinstance(box, list | tuple) and len(box) == 4): + continue + try: + x1, y1, x2, y2 = (float(v) for v in box) + except (TypeError, ValueError): + continue + draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3) + label = str(det.get("label", "")).strip() + if label: + draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR) + elif kind == "keypoint": + point = payload.get("point") + if isinstance(point, list | tuple) and len(point) == 2: + try: + x, y = float(point[0]), float(point[1]) + except (TypeError, ValueError): + return img + r = 6 + draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3) + draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2) + draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2) + label = str(payload.get("label", "")).strip() + if label: + draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR) + return img + + +def _open_file(path: Path) -> None: + """Best-effort open ``path`` in the OS default viewer.""" + try: + if sys.platform == "darwin": + subprocess.run(["open", 
str(path)], check=False) + elif sys.platform.startswith("linux"): + subprocess.run(["xdg-open", str(path)], check=False) + elif os.name == "nt": + os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606 + else: # pragma: no cover - exotic platform + webbrowser.open(path.resolve().as_uri()) + except Exception as exc: # noqa: BLE001 + logger.debug("could not auto-open %s: %s", path, exc) + + +def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path: + """Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it.""" + out = Path(out_dir) + out.mkdir(parents=True, exist_ok=True) + path = out / f"vqa_{int(time.time() * 1000)}.png" + image.save(path) + _open_file(path) + return path + + +# --------------------------------------------------------------------------- +# Orchestrator +# --------------------------------------------------------------------------- + + +def handle_vqa_query( + *, + policy: Any, + observation_provider: Any, + question: str, + state: dict[str, Any], + input_fn: Any = input, + print_fn: Any = print, +) -> None: + """Run one interactive VQA question end to end. + + Called synchronously from the input layer while the runtime is in + ``/vlm`` mode (the action loop is gated off, so the policy is not in + concurrent use). All progress is reported via :func:`push_log` so it + shows up in the state panel's scrollback. 
+ """ + from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415 + + if policy is None or not hasattr(policy, "select_message"): + push_log(state, " [warn] vqa: policy has no select_message — skipping") + return + + observation: dict | None = None + if observation_provider is not None: + try: + observation = observation_provider() + except Exception as exc: # noqa: BLE001 + logger.debug("observation_provider raised %s", exc) + + cameras = available_cameras(observation) + chosen: str | None = None + if cameras: + chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn) + if chosen is None: + push_log(state, " [info] vqa cancelled — no camera selected") + return + push_log(state, f" vqa camera: {camera_short_name(chosen)}") + else: + push_log(state, " [info] vqa: no camera available — answering text-only") + + # Ground the question on the chosen camera only — filter the + # observation to that one image (+ proprio state) so the VLM + # prefix matches the single-image ``ask_vqa_*`` training recipe. 
+ vqa_obs: dict | None = None + if observation is not None and chosen is not None: + vqa_obs = {chosen: observation[chosen]} + if "observation.state" in observation: + vqa_obs["observation.state"] = observation["observation.state"] + + answer = _generate_with_policy( + policy, + _msgs_for_vqa(question), + observation=vqa_obs, + state=state, + label="vqa gen", + ) + if not answer: + push_log(state, " [info] vqa gen returned empty") + return + push_log(state, f" vqa: {answer}") + + parsed = parse_vqa_answer(answer) + if not answer_has_overlay(parsed): + if parsed is None: + push_log(state, " [info] vqa answer is not JSON — no overlay") + return + if observation is None or chosen is None: + push_log(state, " [info] no camera image — cannot draw overlay") + return + try: + pil = observation_image_to_pil(observation[chosen]) + overlay = draw_vqa_overlay(pil, parsed) + path = save_and_open_overlay(overlay) + push_log(state, f" vqa overlay saved: {path}") + except Exception as exc: # noqa: BLE001 + logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)) + push_log(state, f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}") diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index 941ba9641..3182386c5 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -898,6 +898,126 @@ def _build_robot_action_executor( return _executor +def _dataset_task_strings(ds_meta: Any) -> list[str]: + """Pull the unique task strings from a ``LeRobotDatasetMetadata``. + + ``ds_meta.tasks`` is a pandas DataFrame indexed by the task string; + return the index as a plain list (empty when no dataset / no tasks). 
+ """ + if ds_meta is None: + return [] + tasks = getattr(ds_meta, "tasks", None) + if tasks is None: + return [] + try: + return [str(t) for t in list(tasks.index)] + except Exception: # noqa: BLE001 + return [] + + +def _select_task_interactively(ds_meta: Any, current_task: str | None) -> str | None: + """Prompt the operator to pick a task from the dataset or type one. + + Called at startup when no ``--task`` was given. Non-TTY / scripted + runs return ``current_task`` unchanged so the existing + "first stdin line becomes the task" behaviour is preserved. + """ + if current_task: + return current_task + if not sys.stdin.isatty(): + return current_task + + tasks = _dataset_task_strings(ds_meta) + if not tasks: + try: + typed = input("[smolvla2] Enter the task: ").strip() + except (EOFError, KeyboardInterrupt): + return current_task + return typed or current_task + + print("[smolvla2] Select a task:", flush=True) + for i, task in enumerate(tasks, 1): + print(f" [{i}] {task}", flush=True) + print(" [c] type a custom task", flush=True) + try: + raw = input("task> ").strip() + except (EOFError, KeyboardInterrupt): + return current_task + if not raw: + return tasks[0] + if raw.lower() in {"c", "custom"}: + try: + return input("[smolvla2] Enter the task: ").strip() or current_task + except (EOFError, KeyboardInterrupt): + return current_task + if raw.isdigit(): + idx = int(raw) - 1 + if 0 <= idx < len(tasks): + return tasks[idx] + print("[smolvla2] invalid choice — using the first task", flush=True) + return tasks[0] + # Treat anything else as a custom task string typed directly. 
+ return raw + + +def _print_runtime_help() -> None: + """Print the slash-command reference.""" + print( + "[smolvla2] commands:\n" + " /action run the robot (default mode)\n" + " /vlm pause the action loop; typed lines become VQA questions\n" + " /help show this help\n" + " task: switch task (clears plan / memory / subtask)\n" + " rephrase: reword the task in place\n" + " stop | quit | exit end the session", + flush=True, + ) + + +def _handle_slash_command(runtime: Any, line: str) -> bool: + """Handle ``/action`` / ``/vlm`` / ``/help``. + + Returns ``True`` when ``line`` was a recognised command (and was + consumed), ``False`` otherwise. + """ + cmd = line.strip().lower() + if cmd in {"/action", "/act"}: + runtime.state["mode"] = "action" + print("[smolvla2] mode: action — robot running", flush=True) + return True + if cmd in {"/vlm", "/vqa"}: + runtime.state["mode"] = "vlm" + # Drop any queued chunk so no stale action fires while paused. + queue = runtime.state.get("action_queue") + if hasattr(queue, "clear"): + queue.clear() + print( + "[smolvla2] mode: vlm — action loop paused; type VQA questions", + flush=True, + ) + return True + if cmd in {"/help", "/?"}: + _print_runtime_help() + return True + return False + + +def _run_vqa_query(runtime: Any, question: str) -> None: + """Run one interactive VQA question against the runtime's policy. + + Used by both loops when in ``/vlm`` mode — the action loop is paused + so the policy is free for a synchronous VQA call. + """ + from lerobot.policies.smolvla2.inference.vqa import handle_vqa_query # noqa: PLC0415 + + handle_vqa_query( + policy=runtime.policy, + observation_provider=runtime.observation_provider, + question=question, + state=runtime.state, + ) + + def _run_autonomous( runtime: Any, *, @@ -963,7 +1083,8 @@ def _run_autonomous( ) redraw() print( - " [autonomous] type interjections / '?' questions on stdin, " + " [autonomous] type interjections / '?' 
questions on stdin; " + "/vlm for VQA mode, /action to resume, /help for commands, " "'stop' or Ctrl+C to quit", flush=True, ) @@ -998,6 +1119,9 @@ def _run_autonomous( lower = line.lower() if lower in {"stop", "quit", "exit"}: break + # Slash commands (/action, /vlm, /help) flip the run mode. + if _handle_slash_command(runtime, line): + continue # ``task: `` always overrides the active task — both # at first set and to switch tasks mid-run. Without the # prefix and with a task already set, an utterance becomes @@ -1036,6 +1160,12 @@ def _run_autonomous( if not runtime.state.get("task"): runtime.set_task(line) continue + # ``/vlm`` mode: the whole line is a VQA question, handled + # synchronously (the action loop is paused so the policy is + # not in concurrent use by the background runtime thread). + if runtime.state.get("mode", "action") == "vlm": + _run_vqa_query(runtime, line) + continue if lower.endswith("?"): runtime.state["recent_vqa_query"] = line runtime.state.setdefault("events_this_tick", []).append("user_vqa_query") @@ -1080,8 +1210,16 @@ def _make_state_panel_renderer( def _redraw(robot_lines: list[str] | None = None) -> None: console.clear() - console.rule(f"[bold]SmolVLA2[/] · {mode_label}", style="cyan") st = runtime.state + run_mode = st.get("mode", "action") + mode_tag = ( + "[green]mode: action[/]" + if run_mode == "action" + else "[yellow]mode: vlm (action loop paused)[/]" + ) + console.rule( + f"[bold]SmolVLA2[/] · {mode_label} · {mode_tag}", style="cyan" + ) for key, label in ( ("task", "task"), ("current_subtask", "subtask"), @@ -1157,8 +1295,9 @@ def _make_state_panel_renderer( console.print() if not st.get("task"): console.print( - " [dim]Type the task to begin. Lines ending in '?' are VQA, " - "anything else is an interjection. Type 'stop' to exit.[/]" + " [dim]Type the task to begin. /vlm switches to VQA mode, " + "/action resumes the robot, /help lists commands. 
" + "Type 'stop' to exit.[/]" ) return _redraw @@ -1259,6 +1398,12 @@ def main(argv: list[str] | None = None) -> int: flush=True, ) + # No task yet (no --task, no canonical dataset task) — let the + # operator pick one from the dataset's task list or type a custom + # one. Non-TTY runs keep the "first stdin line is the task" path. + if not args.task: + args.task = _select_task_interactively(ds_meta, args.task) + observation_provider: Callable[[], dict | None] | None = None robot_executor: Callable[[Any], None] | None = None robot = None @@ -1415,6 +1560,28 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) if lower in {"stop", "quit", "exit"}: break + # Slash commands (/action, /vlm, /help) flip the run mode. + if _handle_slash_command(runtime, line): + _redraw(last_logs) + continue + + # ``/vlm`` mode: a typed line (that isn't a task command) is + # a VQA question — run it synchronously and skip the action + # pipeline tick entirely. + if ( + runtime.state.get("task") + and runtime.state.get("mode", "action") == "vlm" + and not lower.startswith(("task:", "rephrase:")) + ): + runtime.state["log_lines"] = [] + _run_vqa_query(runtime, line) + last_logs = list(runtime.state.get("log_lines") or []) + _redraw(last_logs) + ticks_done += 1 + if max_ticks is not None and ticks_done >= max_ticks: + break + continue + # Inject the user input as the right kind of event, # then run a single pipeline tick to consume it. if lower.startswith("task:"): diff --git a/tests/policies/smolvla/test_smolvla2_vqa_overlay.py b/tests/policies/smolvla/test_smolvla2_vqa_overlay.py new file mode 100644 index 000000000..c79a6df26 --- /dev/null +++ b/tests/policies/smolvla/test_smolvla2_vqa_overlay.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the SmolVLA2 runtime's interactive-VQA helpers. + +Covers camera selection, VQA-answer parsing, and the bounding-box / +keypoint overlay drawing — the pure functions, no model load. +""" + +import numpy as np +import pytest + +from lerobot.policies.smolvla2.inference.vqa import ( + answer_has_overlay, + available_cameras, + camera_short_name, + draw_vqa_overlay, + observation_image_to_pil, + parse_vqa_answer, + prompt_camera_choice, +) + +PIL = pytest.importorskip("PIL") +from PIL import Image # noqa: E402 + +# --------------------------------------------------------------------------- +# Camera selection +# --------------------------------------------------------------------------- + + +def test_available_cameras_extracts_and_sorts_image_keys(): + observation = { + "observation.images.wrist": object(), + "observation.state": object(), + "observation.images.top": object(), + "task": "x", + } + assert available_cameras(observation) == [ + "observation.images.top", + "observation.images.wrist", + ] + + +def test_available_cameras_handles_none_and_empty(): + assert available_cameras(None) == [] + assert available_cameras({}) == [] + + +def test_camera_short_name_strips_prefix(): + assert camera_short_name("observation.images.top") == "top" + assert camera_short_name("top") == "top" + + +def test_prompt_camera_choice_single_camera_auto_selects(): + cams = ["observation.images.top"] + # input_fn must never be called for a single-camera setup. 
+ chosen = prompt_camera_choice(cams, input_fn=_boom, print_fn=lambda *_: None) + assert chosen == "observation.images.top" + + +def test_prompt_camera_choice_by_number(): + cams = ["observation.images.top", "observation.images.wrist"] + chosen = prompt_camera_choice(cams, input_fn=lambda _: "2", print_fn=lambda *_: None) + assert chosen == "observation.images.wrist" + + +def test_prompt_camera_choice_by_name(): + cams = ["observation.images.top", "observation.images.wrist"] + chosen = prompt_camera_choice(cams, input_fn=lambda _: "top", print_fn=lambda *_: None) + assert chosen == "observation.images.top" + + +def test_prompt_camera_choice_invalid_returns_none(): + cams = ["observation.images.top", "observation.images.wrist"] + assert prompt_camera_choice(cams, input_fn=lambda _: "99", print_fn=lambda *_: None) is None + + +def _boom(*_args, **_kwargs): + raise AssertionError("input_fn should not be called") + + +# --------------------------------------------------------------------------- +# Answer parsing +# --------------------------------------------------------------------------- + + +def test_parse_bbox_answer(): + answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}' + parsed = parse_vqa_answer(answer) + assert parsed["kind"] == "bbox" + assert answer_has_overlay(parsed) + + +def test_parse_keypoint_answer(): + answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}' + parsed = parse_vqa_answer(answer) + assert parsed["kind"] == "keypoint" + assert answer_has_overlay(parsed) + + +def test_parse_count_answer_is_not_an_overlay(): + parsed = parse_vqa_answer('{"label": "cubes", "count": 2}') + assert parsed["kind"] == "count" + assert not answer_has_overlay(parsed) + + +def test_parse_invalid_json_returns_none(): + assert parse_vqa_answer("not json at all") is None + assert parse_vqa_answer("") is None + # A JSON array is valid JSON but not a VQA answer object. 
+ assert parse_vqa_answer("[1, 2, 3]") is None + + +def test_parse_unknown_shape(): + parsed = parse_vqa_answer('{"weird": "payload"}') + assert parsed["kind"] == "unknown" + assert not answer_has_overlay(parsed) + + +# --------------------------------------------------------------------------- +# Overlay drawing +# --------------------------------------------------------------------------- + + +def _blank(size=(160, 120)): + return Image.new("RGB", size, (0, 0, 0)) + + +def test_draw_bbox_overlay_changes_pixels_and_preserves_size(): + img = _blank() + parsed = parse_vqa_answer( + '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}' + ) + out = draw_vqa_overlay(img, parsed) + assert out.size == img.size + assert out.tobytes() != img.tobytes() + + +def test_draw_keypoint_overlay_changes_pixels(): + img = _blank() + parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}') + out = draw_vqa_overlay(img, parsed) + assert out.size == img.size + assert out.tobytes() != img.tobytes() + + +def test_draw_overlay_non_spatial_leaves_image_unchanged(): + img = _blank() + parsed = parse_vqa_answer('{"label": "cubes", "count": 2}') + out = draw_vqa_overlay(img, parsed) + assert out.tobytes() == img.tobytes() + + +def test_draw_overlay_tolerates_malformed_coordinates(): + img = _blank() + # bbox with the wrong arity must not raise. + out = draw_vqa_overlay(img, {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}}) + assert out.size == img.size + + +def test_observation_image_to_pil_from_batched_float_array(): + # (1, C, H, W) float array in [0, 1], the runtime observation shape. + arr = np.zeros((1, 3, 24, 32), dtype=np.float32) + pil = observation_image_to_pil(arr) + assert pil.size == (32, 24) + assert pil.mode == "RGB"