mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
feat(smolvla2): startup task picker, /vlm mode toggle, interactive VQA overlay
Three additions to the SmolVLA2 interactive runtime: 1. Startup task picker — when no --task is given, the runtime lists the dataset's task strings as a numbered menu (plus a custom-task option) instead of silently waiting for the first stdin line. 2. Mode toggle — /action and /vlm slash commands flip a persistent run mode. /vlm pauses the whole action loop (HighLevelSubtaskFwd, LowLevelForward and DispatchAction gate on state["mode"]) and clears the action queue so the robot holds position; /action resumes it. The mode is shown in the state panel. 3. Interactive VQA — in /vlm mode a typed line is a VQA question. The new inference/vqa.py module asks which camera to ground on, runs the VLM on that single camera, and when the answer is a bbox/keypoint it draws the overlay, saves a PNG to ./vqa_overlays/ and auto-opens it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,179 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
|
||||
|
||||
Covers camera selection, VQA-answer parsing, and the bounding-box /
|
||||
keypoint overlay drawing — the pure functions, no model load.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.smolvla2.inference.vqa import (
|
||||
answer_has_overlay,
|
||||
available_cameras,
|
||||
camera_short_name,
|
||||
draw_vqa_overlay,
|
||||
observation_image_to_pil,
|
||||
parse_vqa_answer,
|
||||
prompt_camera_choice,
|
||||
)
|
||||
|
||||
PIL = pytest.importorskip("PIL")
|
||||
from PIL import Image # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_available_cameras_extracts_and_sorts_image_keys():
|
||||
observation = {
|
||||
"observation.images.wrist": object(),
|
||||
"observation.state": object(),
|
||||
"observation.images.top": object(),
|
||||
"task": "x",
|
||||
}
|
||||
assert available_cameras(observation) == [
|
||||
"observation.images.top",
|
||||
"observation.images.wrist",
|
||||
]
|
||||
|
||||
|
||||
def test_available_cameras_handles_none_and_empty():
|
||||
assert available_cameras(None) == []
|
||||
assert available_cameras({}) == []
|
||||
|
||||
|
||||
def test_camera_short_name_strips_prefix():
|
||||
assert camera_short_name("observation.images.top") == "top"
|
||||
assert camera_short_name("top") == "top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_single_camera_auto_selects():
|
||||
cams = ["observation.images.top"]
|
||||
# input_fn must never be called for a single-camera setup.
|
||||
chosen = prompt_camera_choice(cams, input_fn=_boom, print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_number():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
chosen = prompt_camera_choice(cams, input_fn=lambda _: "2", print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.wrist"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_name():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
chosen = prompt_camera_choice(cams, input_fn=lambda _: "top", print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_invalid_returns_none():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
assert prompt_camera_choice(cams, input_fn=lambda _: "99", print_fn=lambda *_: None) is None
|
||||
|
||||
|
||||
def _boom(*_args, **_kwargs):
|
||||
raise AssertionError("input_fn should not be called")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_bbox_answer():
|
||||
answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_keypoint_answer():
|
||||
answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "keypoint"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_count_answer_is_not_an_overlay():
|
||||
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
|
||||
assert parsed["kind"] == "count"
|
||||
assert not answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_invalid_json_returns_none():
|
||||
assert parse_vqa_answer("not json at all") is None
|
||||
assert parse_vqa_answer("") is None
|
||||
# A JSON array is valid JSON but not a VQA answer object.
|
||||
assert parse_vqa_answer("[1, 2, 3]") is None
|
||||
|
||||
|
||||
def test_parse_unknown_shape():
|
||||
parsed = parse_vqa_answer('{"weird": "payload"}')
|
||||
assert parsed["kind"] == "unknown"
|
||||
assert not answer_has_overlay(parsed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _blank(size=(160, 120)):
|
||||
return Image.new("RGB", size, (0, 0, 0))
|
||||
|
||||
|
||||
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer(
|
||||
'{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
|
||||
)
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
|
||||
|
||||
def test_draw_keypoint_overlay_changes_pixels():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_non_spatial_leaves_image_unchanged():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.tobytes() == img.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_tolerates_malformed_coordinates():
|
||||
img = _blank()
|
||||
# bbox with the wrong arity must not raise.
|
||||
out = draw_vqa_overlay(img, {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}})
|
||||
assert out.size == img.size
|
||||
|
||||
|
||||
def test_observation_image_to_pil_from_batched_float_array():
|
||||
# (1, C, H, W) float array in [0, 1], the runtime observation shape.
|
||||
arr = np.zeros((1, 3, 24, 32), dtype=np.float32)
|
||||
pil = observation_image_to_pil(arr)
|
||||
assert pil.size == (32, 24)
|
||||
assert pil.mode == "RGB"
|
||||
Reference in New Issue
Block a user