feat(smolvla2): startup task picker, /vlm mode toggle, interactive VQA overlay

Three additions to the SmolVLA2 interactive runtime:

1. Startup task picker — when no --task is given, the runtime lists the
   dataset's task strings as a numbered menu (plus a custom-task option)
   instead of silently waiting for the first stdin line.

2. Mode toggle — /action and /vlm slash commands flip a persistent run
   mode. /vlm pauses the whole action loop (HighLevelSubtaskFwd,
   LowLevelForward and DispatchAction gate on state["mode"]) and clears
   the action queue so the robot holds position; /action resumes it.
   The mode is shown in the state panel.

3. Interactive VQA — in /vlm mode a typed line is a VQA question. The
   new inference/vqa.py module asks which camera to ground on, runs the
   VLM on that single camera, and when the answer is a bbox/keypoint it
   draws the overlay, saves a PNG to ./vqa_overlays/ and auto-opens it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 11:20:57 +02:00
parent bfb8cfb432
commit 26cb38a7d0
7 changed files with 734 additions and 5 deletions
@@ -16,11 +16,16 @@
Reads non-blocking stdin lines, classifies each one heuristically:
"stop" / "quit" / "exit" → state["stop"] = True
"/action" / "/vlm" → set state["mode"]
ends with "?" → user_vqa_query event
starts with "task:" or first line → set runtime task
anything else → user_interjection event
Plugged into the runtime via ``event_collector=StdinReader().poll``.
Note: the shipped CLI (``lerobot-smolvla2-runtime``) drives stdin
directly in its REPL / autonomous loops and does *not* wire this
collector; it's kept as the documented embedding hook and for tests.
"""
from __future__ import annotations
@@ -70,6 +75,19 @@ class StdinReader:
state["stop"] = True
return
# Slash commands flip the run mode. ``/vlm`` pauses the action
# loop (the action steps gate on ``state["mode"]``); ``/action``
# resumes it.
if lower in {"/action", "/act"}:
state["mode"] = "action"
return
if lower in {"/vlm", "/vqa"}:
state["mode"] = "vlm"
queue = state.get("action_queue")
if hasattr(queue, "clear"):
queue.clear()
return
# First non-control line sets the task if no task is active.
if not state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
@@ -33,6 +33,9 @@ Stable keys (read by multiple steps):
events_this_tick list[str] triggers consumed this tick
_tick Tick current tick (set by the loop)
mode str "action" (run the robot) | "vlm" (VQA only,
action loop paused)
log_lines list[str] human-readable status lines printed each tick
"""
@@ -54,6 +57,7 @@ def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
"tool_calls_pending": [],
"events_this_tick": [],
"log_lines": [],
"mode": "action",
"stop": False,
}
@@ -91,6 +91,10 @@ class LowLevelForward(InferenceStep):
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or self.observation_provider is None:
return None
# ``/vlm`` mode pauses the whole action loop so the robot holds
# position while the operator probes the VLM with VQA.
if state.get("mode", "action") != "action":
return None
if not state.get("task"):
return None
@@ -197,6 +201,12 @@ class DispatchAction(InferenceStep):
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
import time as _time # noqa: PLC0415
# ``/vlm`` mode pauses dispatch — the robot holds its last
# commanded position while the operator runs VQA.
if state.get("mode", "action") != "action":
self._last_dispatch_t = None
return None
queue = state.get("action_queue")
if not queue:
# Reset wall-clock anchor when the queue is empty so the
@@ -366,6 +376,10 @@ class HighLevelSubtaskFwd(InferenceStep):
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not state.get("task"):
return None
# ``/vlm`` mode pauses subtask generation along with the rest of
# the action loop.
if state.get("mode", "action") != "action":
return None
# Gate to chunk boundaries: only generate a fresh subtask when
# the action queue is empty (i.e. right before LowLevelForward
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
@@ -91,7 +91,15 @@ def make_state_panel(state: dict[str, Any]) -> Any:
(str(len(pending)), "bold magenta"),
)
table.add_row("", footer)
return Panel(table, title="[bold]SmolVLA2 state[/]", border_style="cyan")
run_mode = state.get("mode", "action")
mode_tag = (
"[green]action[/]" if run_mode == "action" else "[yellow]vlm (paused)[/]"
)
return Panel(
table,
title=f"[bold]SmolVLA2 state[/] · mode: {mode_tag}",
border_style="cyan",
)
def print_user_line(console: Any, line: str) -> None:
@@ -0,0 +1,339 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interactive VQA for the SmolVLA2 runtime.
In ``/vlm`` mode a typed line is treated as a VQA question. This module
runs the full interactive flow:
1. pull the current observation and list available cameras,
2. ask the operator which camera to ground the question on,
3. generate the answer with the VLM conditioned on that one camera,
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
point (``keypoint``), draw the overlay on the camera frame, save a
PNG to ``./vqa_overlays/`` and auto-open it.
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
(see ``lerobot.annotations.steerable_pipeline.validator``):
* ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}, ...]}``
* ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}``
* ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay.
"""
from __future__ import annotations
import json
import logging
import os
import subprocess
import sys
import time
import webbrowser
from pathlib import Path
from typing import Any
from .runtime_state import push_log
logger = logging.getLogger(__name__)
_IMAGE_PREFIX = "observation.images."
# Iteration order for shape matching — most specific keys first so an
# answer is classified deterministically.
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
_BBOX_COLOR = (255, 64, 64)
_POINT_COLOR = (64, 220, 64)
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def available_cameras(observation: dict | None) -> list[str]:
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
if not observation:
return []
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
def camera_short_name(camera_key: str) -> str:
"""Strip the ``observation.images.`` prefix for display."""
return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key
def prompt_camera_choice(
cameras: list[str],
*,
input_fn: Any = input,
print_fn: Any = print,
) -> str | None:
"""Ask the operator which camera to ground a VQA question on.
Accepts either the menu number or the (short or full) camera name.
A single-camera setup auto-selects without prompting. Returns the
chosen ``observation.images.*`` key, or ``None`` if the operator
cancels / gives an invalid answer.
"""
if not cameras:
return None
if len(cameras) == 1:
return cameras[0]
print_fn("Which camera should I look at?")
for i, cam in enumerate(cameras, 1):
print_fn(f" [{i}] {camera_short_name(cam)}")
try:
raw = str(input_fn("camera> ")).strip()
except (EOFError, KeyboardInterrupt):
return None
if not raw:
return cameras[0]
if raw.isdigit():
idx = int(raw) - 1
return cameras[idx] if 0 <= idx < len(cameras) else None
for cam in cameras:
if raw == cam or raw == camera_short_name(cam):
return cam
return None
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def parse_vqa_answer(answer: str) -> dict | None:
"""Parse a VQA answer string into ``{"kind", "payload"}``.
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
when the JSON doesn't match any known shape. Returns ``None`` when
the answer is not parseable JSON / not a JSON object.
"""
if not answer or not answer.strip():
return None
try:
payload = json.loads(answer)
except (ValueError, TypeError):
return None
if not isinstance(payload, dict):
return None
try:
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
VQA_ANSWER_SHAPES,
)
shapes = VQA_ANSWER_SHAPES
except ImportError: # pragma: no cover - annotation extra not installed
shapes = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
keys = set(payload)
for kind in _SHAPE_ORDER:
required = shapes.get(kind)
if required and required <= keys:
return {"kind": kind, "payload": payload}
return {"kind": "unknown", "payload": payload}
def answer_has_overlay(parsed: dict | None) -> bool:
"""True iff ``parsed`` carries drawable spatial coordinates."""
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def observation_image_to_pil(image_tensor: Any) -> Any:
"""Convert an ``observation.images.*`` tensor to a PIL RGB image.
The runtime observation stores images as ``(1, C, H, W)`` (or
``(C, H, W)``) float tensors in ``[0, 1]``. Reuses
``image_array_to_pil_image`` which handles the CHW→HWC transpose and
the float→uint8 scaling.
"""
from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415
arr = image_tensor
if hasattr(arr, "detach"):
arr = arr.detach().cpu()
if hasattr(arr, "numpy"):
arr = arr.numpy()
while arr.ndim > 3: # drop leading batch dim(s)
arr = arr[0]
return image_array_to_pil_image(arr).convert("RGB")
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
``unknown``) are returned as an unmodified copy.
"""
from PIL import ImageDraw # noqa: PLC0415
img = image.convert("RGB").copy()
kind = parsed.get("kind")
payload = parsed.get("payload") or {}
draw = ImageDraw.Draw(img)
if kind == "bbox":
for det in payload.get("detections") or []:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
label = str(det.get("label", "")).strip()
if label:
draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
elif kind == "keypoint":
point = payload.get("point")
if isinstance(point, list | tuple) and len(point) == 2:
try:
x, y = float(point[0]), float(point[1])
except (TypeError, ValueError):
return img
r = 6
draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
label = str(payload.get("label", "")).strip()
if label:
draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
return img
def _open_file(path: Path) -> None:
"""Best-effort open ``path`` in the OS default viewer."""
try:
if sys.platform == "darwin":
subprocess.run(["open", str(path)], check=False)
elif sys.platform.startswith("linux"):
subprocess.run(["xdg-open", str(path)], check=False)
elif os.name == "nt":
os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606
else: # pragma: no cover - exotic platform
webbrowser.open(path.resolve().as_uri())
except Exception as exc: # noqa: BLE001
logger.debug("could not auto-open %s: %s", path, exc)
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
"""Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it."""
out = Path(out_dir)
out.mkdir(parents=True, exist_ok=True)
path = out / f"vqa_{int(time.time() * 1000)}.png"
image.save(path)
_open_file(path)
return path
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------
def handle_vqa_query(
*,
policy: Any,
observation_provider: Any,
question: str,
state: dict[str, Any],
input_fn: Any = input,
print_fn: Any = print,
) -> None:
"""Run one interactive VQA question end to end.
Called synchronously from the input layer while the runtime is in
``/vlm`` mode (the action loop is gated off, so the policy is not in
concurrent use). All progress is reported via :func:`push_log` so it
shows up in the state panel's scrollback.
"""
from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415
if policy is None or not hasattr(policy, "select_message"):
push_log(state, " [warn] vqa: policy has no select_message — skipping")
return
observation: dict | None = None
if observation_provider is not None:
try:
observation = observation_provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s", exc)
cameras = available_cameras(observation)
chosen: str | None = None
if cameras:
chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
if chosen is None:
push_log(state, " [info] vqa cancelled — no camera selected")
return
push_log(state, f" vqa camera: {camera_short_name(chosen)}")
else:
push_log(state, " [info] vqa: no camera available — answering text-only")
# Ground the question on the chosen camera only — filter the
# observation to that one image (+ proprio state) so the VLM
# prefix matches the single-image ``ask_vqa_*`` training recipe.
vqa_obs: dict | None = None
if observation is not None and chosen is not None:
vqa_obs = {chosen: observation[chosen]}
if "observation.state" in observation:
vqa_obs["observation.state"] = observation["observation.state"]
answer = _generate_with_policy(
policy,
_msgs_for_vqa(question),
observation=vqa_obs,
state=state,
label="vqa gen",
)
if not answer:
push_log(state, " [info] vqa gen returned empty")
return
push_log(state, f" vqa: {answer}")
parsed = parse_vqa_answer(answer)
if not answer_has_overlay(parsed):
if parsed is None:
push_log(state, " [info] vqa answer is not JSON — no overlay")
return
if observation is None or chosen is None:
push_log(state, " [info] no camera image — cannot draw overlay")
return
try:
pil = observation_image_to_pil(observation[chosen])
overlay = draw_vqa_overlay(pil, parsed)
path = save_and_open_overlay(overlay)
push_log(state, f" vqa overlay saved: {path}")
except Exception as exc: # noqa: BLE001
logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
push_log(state, f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")
+171 -4
View File
@@ -898,6 +898,126 @@ def _build_robot_action_executor(
return _executor
def _dataset_task_strings(ds_meta: Any) -> list[str]:
"""Pull the unique task strings from a ``LeRobotDatasetMetadata``.
``ds_meta.tasks`` is a pandas DataFrame indexed by the task string;
return the index as a plain list (empty when no dataset / no tasks).
"""
if ds_meta is None:
return []
tasks = getattr(ds_meta, "tasks", None)
if tasks is None:
return []
try:
return [str(t) for t in list(tasks.index)]
except Exception: # noqa: BLE001
return []
def _select_task_interactively(ds_meta: Any, current_task: str | None) -> str | None:
"""Prompt the operator to pick a task from the dataset or type one.
Called at startup when no ``--task`` was given. Non-TTY / scripted
runs return ``current_task`` unchanged so the existing
"first stdin line becomes the task" behaviour is preserved.
"""
if current_task:
return current_task
if not sys.stdin.isatty():
return current_task
tasks = _dataset_task_strings(ds_meta)
if not tasks:
try:
typed = input("[smolvla2] Enter the task: ").strip()
except (EOFError, KeyboardInterrupt):
return current_task
return typed or current_task
print("[smolvla2] Select a task:", flush=True)
for i, task in enumerate(tasks, 1):
print(f" [{i}] {task}", flush=True)
print(" [c] type a custom task", flush=True)
try:
raw = input("task> ").strip()
except (EOFError, KeyboardInterrupt):
return current_task
if not raw:
return tasks[0]
if raw.lower() in {"c", "custom"}:
try:
return input("[smolvla2] Enter the task: ").strip() or current_task
except (EOFError, KeyboardInterrupt):
return current_task
if raw.isdigit():
idx = int(raw) - 1
if 0 <= idx < len(tasks):
return tasks[idx]
print("[smolvla2] invalid choice — using the first task", flush=True)
return tasks[0]
# Treat anything else as a custom task string typed directly.
return raw
def _print_runtime_help() -> None:
"""Print the slash-command reference."""
print(
"[smolvla2] commands:\n"
" /action run the robot (default mode)\n"
" /vlm pause the action loop; typed lines become VQA questions\n"
" /help show this help\n"
" task: <text> switch task (clears plan / memory / subtask)\n"
" rephrase: <text> reword the task in place\n"
" stop | quit | exit end the session",
flush=True,
)
def _handle_slash_command(runtime: Any, line: str) -> bool:
"""Handle ``/action`` / ``/vlm`` / ``/help``.
Returns ``True`` when ``line`` was a recognised command (and was
consumed), ``False`` otherwise.
"""
cmd = line.strip().lower()
if cmd in {"/action", "/act"}:
runtime.state["mode"] = "action"
print("[smolvla2] mode: action — robot running", flush=True)
return True
if cmd in {"/vlm", "/vqa"}:
runtime.state["mode"] = "vlm"
# Drop any queued chunk so no stale action fires while paused.
queue = runtime.state.get("action_queue")
if hasattr(queue, "clear"):
queue.clear()
print(
"[smolvla2] mode: vlm — action loop paused; type VQA questions",
flush=True,
)
return True
if cmd in {"/help", "/?"}:
_print_runtime_help()
return True
return False
def _run_vqa_query(runtime: Any, question: str) -> None:
"""Run one interactive VQA question against the runtime's policy.
Used by both loops when in ``/vlm`` mode — the action loop is paused
so the policy is free for a synchronous VQA call.
"""
from lerobot.policies.smolvla2.inference.vqa import handle_vqa_query # noqa: PLC0415
handle_vqa_query(
policy=runtime.policy,
observation_provider=runtime.observation_provider,
question=question,
state=runtime.state,
)
def _run_autonomous(
runtime: Any,
*,
@@ -963,7 +1083,8 @@ def _run_autonomous(
)
redraw()
print(
" [autonomous] type interjections / '?' questions on stdin, "
" [autonomous] type interjections / '?' questions on stdin; "
"/vlm for VQA mode, /action to resume, /help for commands, "
"'stop' or Ctrl+C to quit",
flush=True,
)
@@ -998,6 +1119,9 @@ def _run_autonomous(
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
break
# Slash commands (/action, /vlm, /help) flip the run mode.
if _handle_slash_command(runtime, line):
continue
# ``task: <text>`` always overrides the active task — both
# at first set and to switch tasks mid-run. Without the
# prefix and with a task already set, an utterance becomes
@@ -1036,6 +1160,12 @@ def _run_autonomous(
if not runtime.state.get("task"):
runtime.set_task(line)
continue
# ``/vlm`` mode: the whole line is a VQA question, handled
# synchronously (the action loop is paused so the policy is
# not in concurrent use by the background runtime thread).
if runtime.state.get("mode", "action") == "vlm":
_run_vqa_query(runtime, line)
continue
if lower.endswith("?"):
runtime.state["recent_vqa_query"] = line
runtime.state.setdefault("events_this_tick", []).append("user_vqa_query")
@@ -1080,8 +1210,16 @@ def _make_state_panel_renderer(
def _redraw(robot_lines: list[str] | None = None) -> None:
console.clear()
console.rule(f"[bold]SmolVLA2[/] · {mode_label}", style="cyan")
st = runtime.state
run_mode = st.get("mode", "action")
mode_tag = (
"[green]mode: action[/]"
if run_mode == "action"
else "[yellow]mode: vlm (action loop paused)[/]"
)
console.rule(
f"[bold]SmolVLA2[/] · {mode_label} · {mode_tag}", style="cyan"
)
for key, label in (
("task", "task"),
("current_subtask", "subtask"),
@@ -1157,8 +1295,9 @@ def _make_state_panel_renderer(
console.print()
if not st.get("task"):
console.print(
" [dim]Type the task to begin. Lines ending in '?' are VQA, "
"anything else is an interjection. Type 'stop' to exit.[/]"
" [dim]Type the task to begin. /vlm switches to VQA mode, "
"/action resumes the robot, /help lists commands. "
"Type 'stop' to exit.[/]"
)
return _redraw
@@ -1259,6 +1398,12 @@ def main(argv: list[str] | None = None) -> int:
flush=True,
)
# No task yet (no --task, no canonical dataset task) — let the
# operator pick one from the dataset's task list or type a custom
# one. Non-TTY runs keep the "first stdin line is the task" path.
if not args.task:
args.task = _select_task_interactively(ds_meta, args.task)
observation_provider: Callable[[], dict | None] | None = None
robot_executor: Callable[[Any], None] | None = None
robot = None
@@ -1415,6 +1560,28 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
if lower in {"stop", "quit", "exit"}:
break
# Slash commands (/action, /vlm, /help) flip the run mode.
if _handle_slash_command(runtime, line):
_redraw(last_logs)
continue
# ``/vlm`` mode: a typed line (that isn't a task command) is
# a VQA question — run it synchronously and skip the action
# pipeline tick entirely.
if (
runtime.state.get("task")
and runtime.state.get("mode", "action") == "vlm"
and not lower.startswith(("task:", "rephrase:"))
):
runtime.state["log_lines"] = []
_run_vqa_query(runtime, line)
last_logs = list(runtime.state.get("log_lines") or [])
_redraw(last_logs)
ticks_done += 1
if max_ticks is not None and ticks_done >= max_ticks:
break
continue
# Inject the user input as the right kind of event,
# then run a single pipeline tick to consume it.
if lower.startswith("task:"):
@@ -0,0 +1,179 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
Covers camera selection, VQA-answer parsing, and the bounding-box /
keypoint overlay drawing the pure functions, no model load.
"""
import numpy as np
import pytest
from lerobot.policies.smolvla2.inference.vqa import (
answer_has_overlay,
available_cameras,
camera_short_name,
draw_vqa_overlay,
observation_image_to_pil,
parse_vqa_answer,
prompt_camera_choice,
)
PIL = pytest.importorskip("PIL")
from PIL import Image # noqa: E402
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def test_available_cameras_extracts_and_sorts_image_keys():
observation = {
"observation.images.wrist": object(),
"observation.state": object(),
"observation.images.top": object(),
"task": "x",
}
assert available_cameras(observation) == [
"observation.images.top",
"observation.images.wrist",
]
def test_available_cameras_handles_none_and_empty():
assert available_cameras(None) == []
assert available_cameras({}) == []
def test_camera_short_name_strips_prefix():
assert camera_short_name("observation.images.top") == "top"
assert camera_short_name("top") == "top"
def test_prompt_camera_choice_single_camera_auto_selects():
cams = ["observation.images.top"]
# input_fn must never be called for a single-camera setup.
chosen = prompt_camera_choice(cams, input_fn=_boom, print_fn=lambda *_: None)
assert chosen == "observation.images.top"
def test_prompt_camera_choice_by_number():
cams = ["observation.images.top", "observation.images.wrist"]
chosen = prompt_camera_choice(cams, input_fn=lambda _: "2", print_fn=lambda *_: None)
assert chosen == "observation.images.wrist"
def test_prompt_camera_choice_by_name():
cams = ["observation.images.top", "observation.images.wrist"]
chosen = prompt_camera_choice(cams, input_fn=lambda _: "top", print_fn=lambda *_: None)
assert chosen == "observation.images.top"
def test_prompt_camera_choice_invalid_returns_none():
cams = ["observation.images.top", "observation.images.wrist"]
assert prompt_camera_choice(cams, input_fn=lambda _: "99", print_fn=lambda *_: None) is None
def _boom(*_args, **_kwargs):
raise AssertionError("input_fn should not be called")
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def test_parse_bbox_answer():
answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
parsed = parse_vqa_answer(answer)
assert parsed["kind"] == "bbox"
assert answer_has_overlay(parsed)
def test_parse_keypoint_answer():
answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
parsed = parse_vqa_answer(answer)
assert parsed["kind"] == "keypoint"
assert answer_has_overlay(parsed)
def test_parse_count_answer_is_not_an_overlay():
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
assert parsed["kind"] == "count"
assert not answer_has_overlay(parsed)
def test_parse_invalid_json_returns_none():
assert parse_vqa_answer("not json at all") is None
assert parse_vqa_answer("") is None
# A JSON array is valid JSON but not a VQA answer object.
assert parse_vqa_answer("[1, 2, 3]") is None
def test_parse_unknown_shape():
parsed = parse_vqa_answer('{"weird": "payload"}')
assert parsed["kind"] == "unknown"
assert not answer_has_overlay(parsed)
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def _blank(size=(160, 120)):
return Image.new("RGB", size, (0, 0, 0))
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
img = _blank()
parsed = parse_vqa_answer(
'{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
)
out = draw_vqa_overlay(img, parsed)
assert out.size == img.size
assert out.tobytes() != img.tobytes()
def test_draw_keypoint_overlay_changes_pixels():
img = _blank()
parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
out = draw_vqa_overlay(img, parsed)
assert out.size == img.size
assert out.tobytes() != img.tobytes()
def test_draw_overlay_non_spatial_leaves_image_unchanged():
img = _blank()
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
out = draw_vqa_overlay(img, parsed)
assert out.tobytes() == img.tobytes()
def test_draw_overlay_tolerates_malformed_coordinates():
img = _blank()
# bbox with the wrong arity must not raise.
out = draw_vqa_overlay(img, {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}})
assert out.size == img.size
def test_observation_image_to_pil_from_batched_float_array():
# (1, C, H, W) float array in [0, 1], the runtime observation shape.
arr = np.zeros((1, 3, 24, 32), dtype=np.float32)
pil = observation_image_to_pil(arr)
assert pil.size == (32, 24)
assert pil.mode == "RGB"