feat(smolvla2): startup task picker, /vlm mode toggle, interactive VQA overlay

Three additions to the SmolVLA2 interactive runtime:

1. Startup task picker — when no --task is given, the runtime lists the
   dataset's task strings as a numbered menu (plus a custom-task option)
   instead of silently waiting for the first stdin line.

2. Mode toggle — /action and /vlm slash commands flip a persistent run
   mode. /vlm pauses the whole action loop (HighLevelSubtaskFwd,
   LowLevelForward and DispatchAction gate on state["mode"]) and clears
   the action queue so the robot holds position; /action resumes it.
   The mode is shown in the state panel.

3. Interactive VQA — in /vlm mode a typed line is a VQA question. The
   new inference/vqa.py module asks which camera to ground on, runs the
   VLM on that single camera, and when the answer is a bbox/keypoint it
   draws the overlay, saves a PNG to ./vqa_overlays/ and auto-opens it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 11:20:57 +02:00
parent bfb8cfb432
commit 26cb38a7d0
7 changed files with 734 additions and 5 deletions
@@ -0,0 +1,179 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
Covers camera selection, VQA-answer parsing, and the bounding-box /
keypoint overlay drawing — the pure functions, no model load.
"""
import numpy as np
import pytest
from lerobot.policies.smolvla2.inference.vqa import (
answer_has_overlay,
available_cameras,
camera_short_name,
draw_vqa_overlay,
observation_image_to_pil,
parse_vqa_answer,
prompt_camera_choice,
)
PIL = pytest.importorskip("PIL")
from PIL import Image # noqa: E402
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def test_available_cameras_extracts_and_sorts_image_keys():
observation = {
"observation.images.wrist": object(),
"observation.state": object(),
"observation.images.top": object(),
"task": "x",
}
assert available_cameras(observation) == [
"observation.images.top",
"observation.images.wrist",
]
def test_available_cameras_handles_none_and_empty():
assert available_cameras(None) == []
assert available_cameras({}) == []
def test_camera_short_name_strips_prefix():
assert camera_short_name("observation.images.top") == "top"
assert camera_short_name("top") == "top"
def test_prompt_camera_choice_single_camera_auto_selects():
cams = ["observation.images.top"]
# input_fn must never be called for a single-camera setup.
chosen = prompt_camera_choice(cams, input_fn=_boom, print_fn=lambda *_: None)
assert chosen == "observation.images.top"
def test_prompt_camera_choice_by_number():
cams = ["observation.images.top", "observation.images.wrist"]
chosen = prompt_camera_choice(cams, input_fn=lambda _: "2", print_fn=lambda *_: None)
assert chosen == "observation.images.wrist"
def test_prompt_camera_choice_by_name():
cams = ["observation.images.top", "observation.images.wrist"]
chosen = prompt_camera_choice(cams, input_fn=lambda _: "top", print_fn=lambda *_: None)
assert chosen == "observation.images.top"
def test_prompt_camera_choice_invalid_returns_none():
cams = ["observation.images.top", "observation.images.wrist"]
assert prompt_camera_choice(cams, input_fn=lambda _: "99", print_fn=lambda *_: None) is None
def _boom(*_args, **_kwargs):
raise AssertionError("input_fn should not be called")
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def test_parse_bbox_answer():
answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
parsed = parse_vqa_answer(answer)
assert parsed["kind"] == "bbox"
assert answer_has_overlay(parsed)
def test_parse_keypoint_answer():
answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
parsed = parse_vqa_answer(answer)
assert parsed["kind"] == "keypoint"
assert answer_has_overlay(parsed)
def test_parse_count_answer_is_not_an_overlay():
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
assert parsed["kind"] == "count"
assert not answer_has_overlay(parsed)
def test_parse_invalid_json_returns_none():
assert parse_vqa_answer("not json at all") is None
assert parse_vqa_answer("") is None
# A JSON array is valid JSON but not a VQA answer object.
assert parse_vqa_answer("[1, 2, 3]") is None
def test_parse_unknown_shape():
parsed = parse_vqa_answer('{"weird": "payload"}')
assert parsed["kind"] == "unknown"
assert not answer_has_overlay(parsed)
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def _blank(size=(160, 120)):
return Image.new("RGB", size, (0, 0, 0))
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
img = _blank()
parsed = parse_vqa_answer(
'{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
)
out = draw_vqa_overlay(img, parsed)
assert out.size == img.size
assert out.tobytes() != img.tobytes()
def test_draw_keypoint_overlay_changes_pixels():
img = _blank()
parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
out = draw_vqa_overlay(img, parsed)
assert out.size == img.size
assert out.tobytes() != img.tobytes()
def test_draw_overlay_non_spatial_leaves_image_unchanged():
img = _blank()
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
out = draw_vqa_overlay(img, parsed)
assert out.tobytes() == img.tobytes()
def test_draw_overlay_tolerates_malformed_coordinates():
img = _blank()
# bbox with the wrong arity must not raise.
out = draw_vqa_overlay(img, {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}})
assert out.size == img.size
def test_observation_image_to_pil_from_batched_float_array():
# (1, C, H, W) float array in [0, 1], the runtime observation shape.
arr = np.zeros((1, 3, 24, 32), dtype=np.float32)
pil = observation_image_to_pil(arr)
assert pil.size == (32, 24)
assert pil.mode == "RGB"