mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
26cb38a7d0
Three additions to the SmolVLA2 interactive runtime: 1. Startup task picker — when no --task is given, the runtime lists the dataset's task strings as a numbered menu (plus a custom-task option) instead of silently waiting for the first stdin line. 2. Mode toggle — /action and /vlm slash commands flip a persistent run mode. /vlm pauses the whole action loop (HighLevelSubtaskFwd, LowLevelForward and DispatchAction gate on state["mode"]) and clears the action queue so the robot holds position; /action resumes it. The mode is shown in the state panel. 3. Interactive VQA — in /vlm mode a typed line is a VQA question. The new inference/vqa.py module asks which camera to ground on, runs the VLM on that single camera, and when the answer is a bbox/keypoint it draws the overlay, saves a PNG to ./vqa_overlays/ and auto-opens it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
180 lines
5.8 KiB
Python
180 lines
5.8 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
|
|
|
|
Covers camera selection, VQA-answer parsing, and the bounding-box /
|
|
keypoint overlay drawing — the pure functions, no model load.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from lerobot.policies.smolvla2.inference.vqa import (
|
|
answer_has_overlay,
|
|
available_cameras,
|
|
camera_short_name,
|
|
draw_vqa_overlay,
|
|
observation_image_to_pil,
|
|
parse_vqa_answer,
|
|
prompt_camera_choice,
|
|
)
|
|
|
|
PIL = pytest.importorskip("PIL")
|
|
from PIL import Image # noqa: E402
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Camera selection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_available_cameras_extracts_and_sorts_image_keys():
    """Only ``observation.images.*`` keys are returned, in sorted order."""
    wrist_key = "observation.images.wrist"
    top_key = "observation.images.top"
    obs = {
        wrist_key: object(),
        "observation.state": object(),
        top_key: object(),
        "task": "x",
    }
    # Inserted wrist-before-top so the sort is actually exercised.
    expected = [top_key, wrist_key]
    assert available_cameras(obs) == expected
|
|
|
|
|
|
def test_available_cameras_handles_none_and_empty():
    """A missing (``None``) or empty observation yields no cameras."""
    for empty_observation in (None, {}):
        assert available_cameras(empty_observation) == []
|
|
|
|
|
|
def test_camera_short_name_strips_prefix():
    """The ``observation.images.`` prefix is dropped; bare names pass through."""
    cases = {
        "observation.images.top": "top",
        "top": "top",
    }
    for full_key, short in cases.items():
        assert camera_short_name(full_key) == short
|
|
|
|
|
|
def test_prompt_camera_choice_single_camera_auto_selects():
    """With exactly one camera, no prompt is shown and that camera is returned."""
    only_cam = "observation.images.top"
    # _boom fails the test if input_fn is ever invoked.
    result = prompt_camera_choice(
        [only_cam],
        input_fn=_boom,
        print_fn=lambda *_: None,
    )
    assert result == only_cam
|
|
|
|
|
|
def test_prompt_camera_choice_by_number():
    """A 1-based numeric reply selects the matching camera ("2" -> second)."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(
        cameras,
        input_fn=lambda _: "2",
        print_fn=lambda *_: None,
    )
    assert picked == cameras[1]
|
|
|
|
|
|
def test_prompt_camera_choice_by_name():
    """Typing a camera's short name (without the key prefix) selects it."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(
        cameras,
        input_fn=lambda _: "top",
        print_fn=lambda *_: None,
    )
    assert picked == cameras[0]
|
|
|
|
|
|
def test_prompt_camera_choice_invalid_returns_none():
    """An out-of-range selection yields ``None``."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(
        cameras,
        input_fn=lambda _: "99",
        print_fn=lambda *_: None,
    )
    assert picked is None
|
|
|
|
|
|
def _boom(*_args, **_kwargs):
|
|
raise AssertionError("input_fn should not be called")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Answer parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_parse_bbox_answer():
    """A ``detections`` payload parses as a bbox answer that has an overlay."""
    result = parse_vqa_answer(
        '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    )
    assert result["kind"] == "bbox"
    assert answer_has_overlay(result)
|
|
|
|
|
|
def test_parse_keypoint_answer():
    """A ``point`` payload parses as a keypoint answer that has an overlay."""
    result = parse_vqa_answer(
        '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
    )
    assert result["kind"] == "keypoint"
    assert answer_has_overlay(result)
|
|
|
|
|
|
def test_parse_count_answer_is_not_an_overlay():
    """A ``count`` answer is recognised but has nothing to draw."""
    result = parse_vqa_answer('{"label": "cubes", "count": 2}')
    assert result["kind"] == "count"
    assert answer_has_overlay(result) is not True and not answer_has_overlay(result)
|
|
|
|
|
|
def test_parse_invalid_json_returns_none():
    """Input that is not a JSON object is rejected with ``None``."""
    rejected = (
        "not json at all",
        "",
        # Valid JSON, but an array rather than a VQA answer object.
        "[1, 2, 3]",
    )
    for raw in rejected:
        assert parse_vqa_answer(raw) is None
|
|
|
|
|
|
def test_parse_unknown_shape():
    """A JSON object with no recognised keys is classified as ``unknown``."""
    result = parse_vqa_answer('{"weird": "payload"}')
    assert result["kind"] == "unknown"
    assert not answer_has_overlay(result)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Overlay drawing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _blank(size=(160, 120)):
    """Return a solid-black RGB test image; ``size`` is PIL's (width, height)."""
    black = (0, 0, 0)
    return Image.new("RGB", size, black)
|
|
|
|
|
|
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
    """Drawing a bbox yields a same-size image whose pixels differ from the input.

    The blank frame's bytes are snapshotted *before* drawing. "Pixels changed"
    is then measured against that fixed baseline rather than against the live
    input object. Non-mutation of the input is asserted as its own, separately
    reported property.
    """
    img = _blank()
    baseline = img.tobytes()
    parsed = parse_vqa_answer(
        '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    )
    out = draw_vqa_overlay(img, parsed)
    assert out.size == img.size
    assert out.tobytes() != baseline, "bbox overlay did not change any pixels"
    # The caller's frame must be left untouched (overlay drawn on a copy).
    assert img.tobytes() == baseline, "draw_vqa_overlay mutated its input image"
|
|
|
|
|
|
def test_draw_keypoint_overlay_changes_pixels():
    """Drawing a keypoint yields a same-size image whose pixels differ from the input.

    The baseline bytes are captured before drawing, so the check does not
    depend on whether the drawer returns a copy or the same object. Input
    non-mutation is asserted separately.
    """
    img = _blank()
    baseline = img.tobytes()
    parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
    out = draw_vqa_overlay(img, parsed)
    assert out.size == img.size
    assert out.tobytes() != baseline, "keypoint overlay did not change any pixels"
    # The caller's frame must be left untouched (overlay drawn on a copy).
    assert img.tobytes() == baseline, "draw_vqa_overlay mutated its input image"
|
|
|
|
|
|
def test_draw_overlay_non_spatial_leaves_image_unchanged():
    """A non-spatial (count) answer leaves the image pixel-for-pixel unchanged.

    The comparison is made against bytes captured *before* drawing. Comparing
    ``out`` against ``img`` after the call would pass vacuously if
    ``draw_vqa_overlay`` drew in place and returned the same object.
    """
    img = _blank()
    baseline = img.tobytes()
    parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
    out = draw_vqa_overlay(img, parsed)
    assert out.size == img.size
    assert out.tobytes() == baseline
    # Nothing should have been drawn onto the input either.
    assert img.tobytes() == baseline
|
|
|
|
|
|
def test_draw_overlay_tolerates_malformed_coordinates():
    """A bbox with the wrong number of coordinates must not make drawing raise."""
    frame = _blank()
    malformed = {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}}
    result = draw_vqa_overlay(frame, malformed)
    assert result.size == frame.size
|
|
|
|
|
|
def test_observation_image_to_pil_from_batched_float_array():
    """A batched ``(1, C, H, W)`` float array converts to an RGB PIL image."""
    # The runtime observation shape: float values in [0, 1], batch of one.
    batch, channels, height, width = 1, 3, 24, 32
    frame = np.zeros((batch, channels, height, width), dtype=np.float32)
    image = observation_image_to_pil(frame)
    # PIL reports size as (width, height).
    assert image.size == (width, height)
    assert image.mode == "RGB"
|