mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
feat(smolvla2): startup task picker, /vlm mode toggle, interactive VQA overlay
Three additions to the SmolVLA2 interactive runtime: 1. Startup task picker — when no --task is given, the runtime lists the dataset's task strings as a numbered menu (plus a custom-task option) instead of silently waiting for the first stdin line. 2. Mode toggle — /action and /vlm slash commands flip a persistent run mode. /vlm pauses the whole action loop (HighLevelSubtaskFwd, LowLevelForward and DispatchAction gate on state["mode"]) and clears the action queue so the robot holds position; /action resumes it. The mode is shown in the state panel. 3. Interactive VQA — in /vlm mode a typed line is a VQA question. The new inference/vqa.py module asks which camera to ground on, runs the VLM on that single camera, and when the answer is a bbox/keypoint it draws the overlay, saves a PNG to ./vqa_overlays/ and auto-opens it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -16,11 +16,16 @@
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
"stop" / "quit" / "exit" → state["stop"] = True
|
||||
"/action" / "/vlm" → set state["mode"]
|
||||
ends with "?" → user_vqa_query event
|
||||
starts with "task:" or first line → set runtime task
|
||||
anything else → user_interjection event
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
|
||||
Note: the shipped CLI (``lerobot-smolvla2-runtime``) drives stdin
|
||||
directly in its REPL / autonomous loops and does *not* wire this
|
||||
collector; it's kept as the documented embedding hook and for tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -70,6 +75,19 @@ class StdinReader:
|
||||
state["stop"] = True
|
||||
return
|
||||
|
||||
# Slash commands flip the run mode. ``/vlm`` pauses the action
|
||||
# loop (the action steps gate on ``state["mode"]``); ``/action``
|
||||
# resumes it.
|
||||
if lower in {"/action", "/act"}:
|
||||
state["mode"] = "action"
|
||||
return
|
||||
if lower in {"/vlm", "/vqa"}:
|
||||
state["mode"] = "vlm"
|
||||
queue = state.get("action_queue")
|
||||
if hasattr(queue, "clear"):
|
||||
queue.clear()
|
||||
return
|
||||
|
||||
# First non-control line sets the task if no task is active.
|
||||
if not state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
|
||||
@@ -33,6 +33,9 @@ Stable keys (read by multiple steps):
|
||||
events_this_tick list[str] triggers consumed this tick
|
||||
_tick Tick current tick (set by the loop)
|
||||
|
||||
mode str "action" (run the robot) | "vlm" (VQA only,
|
||||
action loop paused)
|
||||
|
||||
log_lines list[str] human-readable status lines printed each tick
|
||||
"""
|
||||
|
||||
@@ -54,6 +57,7 @@ def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
|
||||
"tool_calls_pending": [],
|
||||
"events_this_tick": [],
|
||||
"log_lines": [],
|
||||
"mode": "action",
|
||||
"stop": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -91,6 +91,10 @@ class LowLevelForward(InferenceStep):
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or self.observation_provider is None:
|
||||
return None
|
||||
# ``/vlm`` mode pauses the whole action loop so the robot holds
|
||||
# position while the operator probes the VLM with VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
if not state.get("task"):
|
||||
return None
|
||||
|
||||
@@ -197,6 +201,12 @@ class DispatchAction(InferenceStep):
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
# ``/vlm`` mode pauses dispatch — the robot holds its last
|
||||
# commanded position while the operator runs VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
queue = state.get("action_queue")
|
||||
if not queue:
|
||||
# Reset wall-clock anchor when the queue is empty so the
|
||||
@@ -366,6 +376,10 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not state.get("task"):
|
||||
return None
|
||||
# ``/vlm`` mode pauses subtask generation along with the rest of
|
||||
# the action loop.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
# Gate to chunk boundaries: only generate a fresh subtask when
|
||||
# the action queue is empty (i.e. right before LowLevelForward
|
||||
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
|
||||
|
||||
@@ -91,7 +91,15 @@ def make_state_panel(state: dict[str, Any]) -> Any:
|
||||
(str(len(pending)), "bold magenta"),
|
||||
)
|
||||
table.add_row("", footer)
|
||||
return Panel(table, title="[bold]SmolVLA2 state[/]", border_style="cyan")
|
||||
run_mode = state.get("mode", "action")
|
||||
mode_tag = (
|
||||
"[green]action[/]" if run_mode == "action" else "[yellow]vlm (paused)[/]"
|
||||
)
|
||||
return Panel(
|
||||
table,
|
||||
title=f"[bold]SmolVLA2 state[/] · mode: {mode_tag}",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
||||
|
||||
def print_user_line(console: Any, line: str) -> None:
|
||||
|
||||
@@ -0,0 +1,339 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Interactive VQA for the SmolVLA2 runtime.
|
||||
|
||||
In ``/vlm`` mode a typed line is treated as a VQA question. This module
|
||||
runs the full interactive flow:
|
||||
|
||||
1. pull the current observation and list available cameras,
|
||||
2. ask the operator which camera to ground the question on,
|
||||
3. generate the answer with the VLM conditioned on that one camera,
|
||||
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
|
||||
point (``keypoint``), draw the overlay on the camera frame, save a
|
||||
PNG to ``./vqa_overlays/`` and auto-open it.
|
||||
|
||||
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
|
||||
(see ``lerobot.annotations.steerable_pipeline.validator``):
|
||||
|
||||
* ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}, ...]}``
|
||||
* ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}``
|
||||
* ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IMAGE_PREFIX = "observation.images."
|
||||
|
||||
# Iteration order for shape matching — most specific keys first so an
|
||||
# answer is classified deterministically.
|
||||
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
_BBOX_COLOR = (255, 64, 64)
|
||||
_POINT_COLOR = (64, 220, 64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def available_cameras(observation: dict | None) -> list[str]:
|
||||
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
|
||||
if not observation:
|
||||
return []
|
||||
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
|
||||
|
||||
|
||||
def camera_short_name(camera_key: str) -> str:
    """Return ``camera_key`` without its ``observation.images.`` prefix.

    Keys that do not carry the prefix are returned unchanged, so the
    function is safe to call on an already-shortened name.
    """
    return camera_key.removeprefix(_IMAGE_PREFIX)
|
||||
|
||||
|
||||
def prompt_camera_choice(
|
||||
cameras: list[str],
|
||||
*,
|
||||
input_fn: Any = input,
|
||||
print_fn: Any = print,
|
||||
) -> str | None:
|
||||
"""Ask the operator which camera to ground a VQA question on.
|
||||
|
||||
Accepts either the menu number or the (short or full) camera name.
|
||||
A single-camera setup auto-selects without prompting. Returns the
|
||||
chosen ``observation.images.*`` key, or ``None`` if the operator
|
||||
cancels / gives an invalid answer.
|
||||
"""
|
||||
if not cameras:
|
||||
return None
|
||||
if len(cameras) == 1:
|
||||
return cameras[0]
|
||||
print_fn("Which camera should I look at?")
|
||||
for i, cam in enumerate(cameras, 1):
|
||||
print_fn(f" [{i}] {camera_short_name(cam)}")
|
||||
try:
|
||||
raw = str(input_fn("camera> ")).strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return None
|
||||
if not raw:
|
||||
return cameras[0]
|
||||
if raw.isdigit():
|
||||
idx = int(raw) - 1
|
||||
return cameras[idx] if 0 <= idx < len(cameras) else None
|
||||
for cam in cameras:
|
||||
if raw == cam or raw == camera_short_name(cam):
|
||||
return cam
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_vqa_answer(answer: str) -> dict | None:
|
||||
"""Parse a VQA answer string into ``{"kind", "payload"}``.
|
||||
|
||||
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
|
||||
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
|
||||
when the JSON doesn't match any known shape. Returns ``None`` when
|
||||
the answer is not parseable JSON / not a JSON object.
|
||||
"""
|
||||
if not answer or not answer.strip():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(answer)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
|
||||
VQA_ANSWER_SHAPES,
|
||||
)
|
||||
|
||||
shapes = VQA_ANSWER_SHAPES
|
||||
except ImportError: # pragma: no cover - annotation extra not installed
|
||||
shapes = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
keys = set(payload)
|
||||
for kind in _SHAPE_ORDER:
|
||||
required = shapes.get(kind)
|
||||
if required and required <= keys:
|
||||
return {"kind": kind, "payload": payload}
|
||||
return {"kind": "unknown", "payload": payload}
|
||||
|
||||
|
||||
def answer_has_overlay(parsed: dict | None) -> bool:
|
||||
"""True iff ``parsed`` carries drawable spatial coordinates."""
|
||||
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def observation_image_to_pil(image_tensor: Any) -> Any:
    """Turn an ``observation.images.*`` value into a PIL RGB image.

    Objects exposing ``detach`` are detached and moved to CPU, objects
    exposing ``numpy`` are converted to arrays, and leading dimensions
    are dropped until at most three remain. The remaining conversion is
    delegated to ``image_array_to_pil_image``; the result is forced to
    RGB mode.
    """
    from lerobot.datasets.image_writer import image_array_to_pil_image  # noqa: PLC0415

    frame = image_tensor.detach().cpu() if hasattr(image_tensor, "detach") else image_tensor
    if hasattr(frame, "numpy"):
        frame = frame.numpy()
    # Peel off leading batch dimension(s), keeping the first element.
    while frame.ndim > 3:
        frame = frame[0]
    pil_image = image_array_to_pil_image(frame)
    return pil_image.convert("RGB")
|
||||
|
||||
|
||||
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
    """Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.

    Args:
        image: PIL image to annotate; it is never modified in place.
        parsed: ``{"kind", "payload"}`` dict as produced by
            ``parse_vqa_answer``.

    Returns:
        An RGB copy of ``image``. ``bbox`` answers get one rectangle (and
        label, when present) per well-formed detection; ``keypoint``
        answers get a circle with a crosshair. Any other kind, and any
        malformed detection / point, leaves the copy untouched.
    """
    from PIL import ImageDraw  # noqa: PLC0415

    img = image.convert("RGB").copy()
    kind = parsed.get("kind")
    payload = parsed.get("payload") or {}
    draw = ImageDraw.Draw(img)

    if kind == "bbox":
        detections = payload.get("detections")
        # Model output is untrusted: a scalar here would raise on iteration.
        if not isinstance(detections, list | tuple):
            detections = []
        for det in detections:
            if not isinstance(det, dict):
                continue
            box = det.get("bbox")
            if not (isinstance(box, list | tuple) and len(box) == 4):
                continue
            try:
                xa, ya, xb, yb = (float(v) for v in box)
            except (TypeError, ValueError):
                continue
            # ImageDraw.rectangle requires x1 >= x0 and y1 >= y0; a model
            # that swaps the corners must not abort the whole overlay.
            x1, x2 = sorted((xa, xb))
            y1, y2 = sorted((ya, yb))
            draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
            label = str(det.get("label", "")).strip()
            if label:
                # Place the label just above the box, clamped to the top edge.
                draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
    elif kind == "keypoint":
        point = payload.get("point")
        if isinstance(point, list | tuple) and len(point) == 2:
            try:
                x, y = float(point[0]), float(point[1])
            except (TypeError, ValueError):
                return img
            r = 6
            # Circle plus a crosshair twice the radius wide.
            draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
            draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
            draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
            label = str(payload.get("label", "")).strip()
            if label:
                draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
    return img
|
||||
|
||||
|
||||
def _open_file(path: Path) -> None:
    """Try to show ``path`` in the platform's default viewer.

    Uses ``open`` on macOS, ``xdg-open`` on Linux, ``os.startfile`` on
    Windows and the ``webbrowser`` module elsewhere. Any failure is
    logged at debug level and otherwise ignored.
    """
    try:
        target = str(path)
        platform = sys.platform
        if platform == "darwin":
            launcher = "open"
        elif platform.startswith("linux"):
            launcher = "xdg-open"
        else:
            launcher = None

        if launcher is not None:
            subprocess.run([launcher, target], check=False)
        elif os.name == "nt":
            os.startfile(target)  # type: ignore[attr-defined]  # noqa: S606
        else:  # pragma: no cover - exotic platform
            webbrowser.open(path.resolve().as_uri())
    except Exception as exc:  # noqa: BLE001
        logger.debug("could not auto-open %s: %s", path, exc)
|
||||
|
||||
|
||||
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
    """Write ``image`` to a timestamped PNG in ``out_dir`` and open it.

    The directory is created if missing. The file name embeds the current
    time in milliseconds. Returns the path of the written file.
    """
    directory = Path(out_dir)
    directory.mkdir(parents=True, exist_ok=True)
    stamp_ms = int(time.time() * 1000)
    target = directory / f"vqa_{stamp_ms}.png"
    image.save(target)
    _open_file(target)
    return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def handle_vqa_query(
    *,
    policy: Any,
    observation_provider: Any,
    question: str,
    state: dict[str, Any],
    input_fn: Any = input,
    print_fn: Any = print,
) -> None:
    """Run one interactive VQA question end to end.

    Steps: fetch an observation, let the operator pick a camera, generate
    an answer conditioned on that single camera, then — for ``bbox`` /
    ``keypoint`` answers — draw, save and open an overlay image.

    Args:
        policy: Policy object; must expose ``select_message``.
        observation_provider: Zero-argument callable returning the
            current observation dict, or ``None`` when unavailable.
        question: The operator's question text.
        state: Runtime state dict; progress lines are appended to it via
            ``push_log``.
        input_fn: Callable used to read the camera choice.
        print_fn: Callable used to print the camera menu.

    Nothing is returned; every outcome (including failures) is reported
    through ``push_log``.
    """
    from .steps import _generate_with_policy, _msgs_for_vqa  # noqa: PLC0415

    if policy is None or not hasattr(policy, "select_message"):
        push_log(state, " [warn] vqa: policy has no select_message — skipping")
        return

    observation: dict | None = None
    if observation_provider is not None:
        try:
            observation = observation_provider()
        except Exception as exc:  # noqa: BLE001
            # Surface the failure to the operator: without this the only
            # visible symptom is the generic "no camera available" line.
            logger.warning(
                "observation_provider raised %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)
            )
            push_log(state, f" [warn] vqa: observation unavailable: {type(exc).__name__}: {exc}")

    cameras = available_cameras(observation)
    chosen: str | None = None
    if cameras:
        chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
        if chosen is None:
            push_log(state, " [info] vqa cancelled — no camera selected")
            return
        push_log(state, f" vqa camera: {camera_short_name(chosen)}")
    else:
        push_log(state, " [info] vqa: no camera available — answering text-only")

    # Ground the question on the chosen camera only — filter the
    # observation to that one image (+ proprio state) so the VLM
    # prefix matches the single-image ``ask_vqa_*`` training recipe.
    vqa_obs: dict | None = None
    if observation is not None and chosen is not None:
        vqa_obs = {chosen: observation[chosen]}
        if "observation.state" in observation:
            vqa_obs["observation.state"] = observation["observation.state"]

    answer = _generate_with_policy(
        policy,
        _msgs_for_vqa(question),
        observation=vqa_obs,
        state=state,
        label="vqa gen",
    )
    if not answer:
        push_log(state, " [info] vqa gen returned empty")
        return
    push_log(state, f" vqa: {answer}")

    parsed = parse_vqa_answer(answer)
    if not answer_has_overlay(parsed):
        # Text-only answers end here; only flag the unparseable case.
        if parsed is None:
            push_log(state, " [info] vqa answer is not JSON — no overlay")
        return
    if observation is None or chosen is None:
        push_log(state, " [info] no camera image — cannot draw overlay")
        return
    try:
        pil = observation_image_to_pil(observation[chosen])
        overlay = draw_vqa_overlay(pil, parsed)
        path = save_and_open_overlay(overlay)
        push_log(state, f" vqa overlay saved: {path}")
    except Exception as exc:  # noqa: BLE001
        # The overlay is a convenience; never let it take down the session.
        logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
        push_log(state, f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")
|
||||
@@ -898,6 +898,126 @@ def _build_robot_action_executor(
|
||||
return _executor
|
||||
|
||||
|
||||
def _dataset_task_strings(ds_meta: Any) -> list[str]:
|
||||
"""Pull the unique task strings from a ``LeRobotDatasetMetadata``.
|
||||
|
||||
``ds_meta.tasks`` is a pandas DataFrame indexed by the task string;
|
||||
return the index as a plain list (empty when no dataset / no tasks).
|
||||
"""
|
||||
if ds_meta is None:
|
||||
return []
|
||||
tasks = getattr(ds_meta, "tasks", None)
|
||||
if tasks is None:
|
||||
return []
|
||||
try:
|
||||
return [str(t) for t in list(tasks.index)]
|
||||
except Exception: # noqa: BLE001
|
||||
return []
|
||||
|
||||
|
||||
def _select_task_interactively(ds_meta: Any, current_task: str | None) -> str | None:
|
||||
"""Prompt the operator to pick a task from the dataset or type one.
|
||||
|
||||
Called at startup when no ``--task`` was given. Non-TTY / scripted
|
||||
runs return ``current_task`` unchanged so the existing
|
||||
"first stdin line becomes the task" behaviour is preserved.
|
||||
"""
|
||||
if current_task:
|
||||
return current_task
|
||||
if not sys.stdin.isatty():
|
||||
return current_task
|
||||
|
||||
tasks = _dataset_task_strings(ds_meta)
|
||||
if not tasks:
|
||||
try:
|
||||
typed = input("[smolvla2] Enter the task: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return current_task
|
||||
return typed or current_task
|
||||
|
||||
print("[smolvla2] Select a task:", flush=True)
|
||||
for i, task in enumerate(tasks, 1):
|
||||
print(f" [{i}] {task}", flush=True)
|
||||
print(" [c] type a custom task", flush=True)
|
||||
try:
|
||||
raw = input("task> ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return current_task
|
||||
if not raw:
|
||||
return tasks[0]
|
||||
if raw.lower() in {"c", "custom"}:
|
||||
try:
|
||||
return input("[smolvla2] Enter the task: ").strip() or current_task
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return current_task
|
||||
if raw.isdigit():
|
||||
idx = int(raw) - 1
|
||||
if 0 <= idx < len(tasks):
|
||||
return tasks[idx]
|
||||
print("[smolvla2] invalid choice — using the first task", flush=True)
|
||||
return tasks[0]
|
||||
# Treat anything else as a custom task string typed directly.
|
||||
return raw
|
||||
|
||||
|
||||
def _print_runtime_help() -> None:
|
||||
"""Print the slash-command reference."""
|
||||
print(
|
||||
"[smolvla2] commands:\n"
|
||||
" /action run the robot (default mode)\n"
|
||||
" /vlm pause the action loop; typed lines become VQA questions\n"
|
||||
" /help show this help\n"
|
||||
" task: <text> switch task (clears plan / memory / subtask)\n"
|
||||
" rephrase: <text> reword the task in place\n"
|
||||
" stop | quit | exit end the session",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def _handle_slash_command(runtime: Any, line: str) -> bool:
|
||||
"""Handle ``/action`` / ``/vlm`` / ``/help``.
|
||||
|
||||
Returns ``True`` when ``line`` was a recognised command (and was
|
||||
consumed), ``False`` otherwise.
|
||||
"""
|
||||
cmd = line.strip().lower()
|
||||
if cmd in {"/action", "/act"}:
|
||||
runtime.state["mode"] = "action"
|
||||
print("[smolvla2] mode: action — robot running", flush=True)
|
||||
return True
|
||||
if cmd in {"/vlm", "/vqa"}:
|
||||
runtime.state["mode"] = "vlm"
|
||||
# Drop any queued chunk so no stale action fires while paused.
|
||||
queue = runtime.state.get("action_queue")
|
||||
if hasattr(queue, "clear"):
|
||||
queue.clear()
|
||||
print(
|
||||
"[smolvla2] mode: vlm — action loop paused; type VQA questions",
|
||||
flush=True,
|
||||
)
|
||||
return True
|
||||
if cmd in {"/help", "/?"}:
|
||||
_print_runtime_help()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _run_vqa_query(runtime: Any, question: str) -> None:
    """Forward one VQA question to ``handle_vqa_query``.

    The runtime's ``policy``, ``observation_provider`` and ``state`` are
    passed through unchanged; the call is synchronous.
    """
    from lerobot.policies.smolvla2.inference.vqa import handle_vqa_query  # noqa: PLC0415

    call_args = {
        "policy": runtime.policy,
        "observation_provider": runtime.observation_provider,
        "question": question,
        "state": runtime.state,
    }
    handle_vqa_query(**call_args)
|
||||
|
||||
|
||||
def _run_autonomous(
|
||||
runtime: Any,
|
||||
*,
|
||||
@@ -963,7 +1083,8 @@ def _run_autonomous(
|
||||
)
|
||||
redraw()
|
||||
print(
|
||||
" [autonomous] type interjections / '?' questions on stdin, "
|
||||
" [autonomous] type interjections / '?' questions on stdin; "
|
||||
"/vlm for VQA mode, /action to resume, /help for commands, "
|
||||
"'stop' or Ctrl+C to quit",
|
||||
flush=True,
|
||||
)
|
||||
@@ -998,6 +1119,9 @@ def _run_autonomous(
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
break
|
||||
# Slash commands (/action, /vlm, /help) flip the run mode.
|
||||
if _handle_slash_command(runtime, line):
|
||||
continue
|
||||
# ``task: <text>`` always overrides the active task — both
|
||||
# at first set and to switch tasks mid-run. Without the
|
||||
# prefix and with a task already set, an utterance becomes
|
||||
@@ -1036,6 +1160,12 @@ def _run_autonomous(
|
||||
if not runtime.state.get("task"):
|
||||
runtime.set_task(line)
|
||||
continue
|
||||
# ``/vlm`` mode: the whole line is a VQA question, handled
|
||||
# synchronously (the action loop is paused so the policy is
|
||||
# not in concurrent use by the background runtime thread).
|
||||
if runtime.state.get("mode", "action") == "vlm":
|
||||
_run_vqa_query(runtime, line)
|
||||
continue
|
||||
if lower.endswith("?"):
|
||||
runtime.state["recent_vqa_query"] = line
|
||||
runtime.state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||
@@ -1080,8 +1210,16 @@ def _make_state_panel_renderer(
|
||||
|
||||
def _redraw(robot_lines: list[str] | None = None) -> None:
|
||||
console.clear()
|
||||
console.rule(f"[bold]SmolVLA2[/] · {mode_label}", style="cyan")
|
||||
st = runtime.state
|
||||
run_mode = st.get("mode", "action")
|
||||
mode_tag = (
|
||||
"[green]mode: action[/]"
|
||||
if run_mode == "action"
|
||||
else "[yellow]mode: vlm (action loop paused)[/]"
|
||||
)
|
||||
console.rule(
|
||||
f"[bold]SmolVLA2[/] · {mode_label} · {mode_tag}", style="cyan"
|
||||
)
|
||||
for key, label in (
|
||||
("task", "task"),
|
||||
("current_subtask", "subtask"),
|
||||
@@ -1157,8 +1295,9 @@ def _make_state_panel_renderer(
|
||||
console.print()
|
||||
if not st.get("task"):
|
||||
console.print(
|
||||
" [dim]Type the task to begin. Lines ending in '?' are VQA, "
|
||||
"anything else is an interjection. Type 'stop' to exit.[/]"
|
||||
" [dim]Type the task to begin. /vlm switches to VQA mode, "
|
||||
"/action resumes the robot, /help lists commands. "
|
||||
"Type 'stop' to exit.[/]"
|
||||
)
|
||||
|
||||
return _redraw
|
||||
@@ -1259,6 +1398,12 @@ def main(argv: list[str] | None = None) -> int:
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# No task yet (no --task, no canonical dataset task) — let the
|
||||
# operator pick one from the dataset's task list or type a custom
|
||||
# one. Non-TTY runs keep the "first stdin line is the task" path.
|
||||
if not args.task:
|
||||
args.task = _select_task_interactively(ds_meta, args.task)
|
||||
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
robot_executor: Callable[[Any], None] | None = None
|
||||
robot = None
|
||||
@@ -1415,6 +1560,28 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
break
|
||||
|
||||
# Slash commands (/action, /vlm, /help) flip the run mode.
|
||||
if _handle_slash_command(runtime, line):
|
||||
_redraw(last_logs)
|
||||
continue
|
||||
|
||||
# ``/vlm`` mode: a typed line (that isn't a task command) is
|
||||
# a VQA question — run it synchronously and skip the action
|
||||
# pipeline tick entirely.
|
||||
if (
|
||||
runtime.state.get("task")
|
||||
and runtime.state.get("mode", "action") == "vlm"
|
||||
and not lower.startswith(("task:", "rephrase:"))
|
||||
):
|
||||
runtime.state["log_lines"] = []
|
||||
_run_vqa_query(runtime, line)
|
||||
last_logs = list(runtime.state.get("log_lines") or [])
|
||||
_redraw(last_logs)
|
||||
ticks_done += 1
|
||||
if max_ticks is not None and ticks_done >= max_ticks:
|
||||
break
|
||||
continue
|
||||
|
||||
# Inject the user input as the right kind of event,
|
||||
# then run a single pipeline tick to consume it.
|
||||
if lower.startswith("task:"):
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
|
||||
|
||||
Covers camera selection, VQA-answer parsing, and the bounding-box /
|
||||
keypoint overlay drawing — the pure functions, no model load.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.smolvla2.inference.vqa import (
|
||||
answer_has_overlay,
|
||||
available_cameras,
|
||||
camera_short_name,
|
||||
draw_vqa_overlay,
|
||||
observation_image_to_pil,
|
||||
parse_vqa_answer,
|
||||
prompt_camera_choice,
|
||||
)
|
||||
|
||||
PIL = pytest.importorskip("PIL")
|
||||
from PIL import Image # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_available_cameras_extracts_and_sorts_image_keys():
    """Only ``observation.images.*`` keys come back, in sorted order."""
    observation = dict.fromkeys(
        ("observation.images.wrist", "observation.state", "observation.images.top"),
        object(),
    )
    observation["task"] = "x"
    expected = ["observation.images.top", "observation.images.wrist"]
    assert available_cameras(observation) == expected
|
||||
|
||||
|
||||
def test_available_cameras_handles_none_and_empty():
    """A missing or empty observation yields no cameras."""
    for empty_observation in (None, {}):
        assert available_cameras(empty_observation) == []
|
||||
|
||||
|
||||
def test_camera_short_name_strips_prefix():
    """The prefix is removed once; bare names pass through untouched."""
    for given in ("observation.images.top", "top"):
        assert camera_short_name(given) == "top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_single_camera_auto_selects():
    """With exactly one camera the operator is never prompted."""
    only_camera = "observation.images.top"
    # ``_boom`` raises if called, proving the prompt was skipped.
    result = prompt_camera_choice([only_camera], input_fn=_boom, print_fn=lambda *_: None)
    assert result == only_camera
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_number():
    """A 1-based menu number selects the matching camera."""
    top, wrist = "observation.images.top", "observation.images.wrist"
    result = prompt_camera_choice([top, wrist], input_fn=lambda _: "2", print_fn=lambda *_: None)
    assert result == wrist
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_name():
    """Typing the short camera name selects that camera."""
    top, wrist = "observation.images.top", "observation.images.wrist"
    result = prompt_camera_choice([top, wrist], input_fn=lambda _: "top", print_fn=lambda *_: None)
    assert result == top
|
||||
|
||||
|
||||
def test_prompt_camera_choice_invalid_returns_none():
    """An out-of-range menu number is rejected with ``None``."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    result = prompt_camera_choice(cameras, input_fn=lambda _: "99", print_fn=lambda *_: None)
    assert result is None
|
||||
|
||||
|
||||
def _boom(*_args, **_kwargs):
|
||||
raise AssertionError("input_fn should not be called")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_bbox_answer():
    """A ``detections`` payload is classified as a drawable ``bbox``."""
    raw = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    result = parse_vqa_answer(raw)
    assert answer_has_overlay(result)
    assert result["kind"] == "bbox"
|
||||
|
||||
|
||||
def test_parse_keypoint_answer():
    """A label + point payload is classified as a drawable ``keypoint``."""
    raw = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
    result = parse_vqa_answer(raw)
    assert answer_has_overlay(result)
    assert result["kind"] == "keypoint"
|
||||
|
||||
|
||||
def test_parse_count_answer_is_not_an_overlay():
    """``count`` answers parse fine but carry nothing to draw."""
    result = parse_vqa_answer('{"label": "cubes", "count": 2}')
    assert result["kind"] == "count"
    assert answer_has_overlay(result) is False
|
||||
|
||||
|
||||
def test_parse_invalid_json_returns_none():
    """Non-JSON text, empty input and non-object JSON all yield ``None``."""
    rejected_inputs = (
        "not json at all",
        "",
        "[1, 2, 3]",  # valid JSON, but an array rather than an answer object
    )
    for raw in rejected_inputs:
        assert parse_vqa_answer(raw) is None
|
||||
|
||||
|
||||
def test_parse_unknown_shape():
    """A JSON object matching no known shape is tagged ``unknown``."""
    result = parse_vqa_answer('{"weird": "payload"}')
    assert result["kind"] == "unknown"
    assert answer_has_overlay(result) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _blank(size=(160, 120)):
    """Create an all-black RGB image of the given ``(width, height)``."""
    black = (0, 0, 0)
    return Image.new("RGB", size, black)
|
||||
|
||||
|
||||
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
    """Drawing a box alters pixel data but not the image dimensions."""
    source = _blank()
    raw = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    drawn = draw_vqa_overlay(source, parse_vqa_answer(raw))
    assert drawn.size == source.size
    assert drawn.tobytes() != source.tobytes()
|
||||
|
||||
|
||||
def test_draw_keypoint_overlay_changes_pixels():
    """Drawing a keypoint alters pixel data but not the image dimensions."""
    source = _blank()
    raw = '{"label": "cube", "point_format": "xy", "point": [80, 60]}'
    drawn = draw_vqa_overlay(source, parse_vqa_answer(raw))
    assert drawn.size == source.size
    assert drawn.tobytes() != source.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_non_spatial_leaves_image_unchanged():
    """A ``count`` answer produces a pixel-identical copy."""
    source = _blank()
    count_answer = parse_vqa_answer('{"label": "cubes", "count": 2}')
    drawn = draw_vqa_overlay(source, count_answer)
    assert drawn.tobytes() == source.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_tolerates_malformed_coordinates():
    """A bbox with too few coordinates is skipped without raising."""
    source = _blank()
    short_box = {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}}
    drawn = draw_vqa_overlay(source, short_box)
    assert drawn.size == source.size
|
||||
|
||||
|
||||
def test_observation_image_to_pil_from_batched_float_array():
    """A ``(1, C, H, W)`` float array becomes a ``W x H`` RGB image."""
    height, width = 24, 32
    batched = np.zeros((1, 3, height, width), dtype=np.float32)
    converted = observation_image_to_pil(batched)
    assert converted.mode == "RGB"
    assert converted.size == (width, height)
|
||||
Reference in New Issue
Block a user