Files
lerobot/tests/policies/smolvla/test_smolvla2_vqa_overlay.py
T
Pepijn c026aed8f8 feat(pi052): train VQA spatial answers in PaliGemma <loc> format
Spatial VQA answers (bbox / keypoint) were trained as pixel-coordinate
JSON, which conflicts with PaliGemma's pretrained detection prior and
makes the model emit stray <loc> tokens at inference. Convert them to
PaliGemma's native <locNNNN> vocabulary instead so the LM head can reuse
that prior.

Training side (text_processor_pi052.py): a target turn whose content
parses as a bbox/keypoint answer is rewritten to <loc> text, using the
camera frame's native (H, W) from the observation and the preceding
image block. Non-spatial answers, subtask/memory targets and SmolVLA2
keep their JSON form — the dataset stays backbone-agnostic.

Runtime side (smolvla2/inference/vqa.py): parse_vqa_answer detects
<loc> answers (2 locs -> keypoint, 4 -> bbox), returning normalized
[0,1] coords with a normalized flag; draw_vqa_overlay denormalizes
against the chosen camera frame's pixel size.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 20:23:46 +02:00

229 lines
7.6 KiB
Python

#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
Covers camera selection, VQA-answer parsing, and the bounding-box /
keypoint overlay drawing — the pure functions, no model load.
"""
import numpy as np
import pytest
from lerobot.policies.smolvla2.inference.vqa import (
answer_has_overlay,
available_cameras,
camera_short_name,
draw_vqa_overlay,
observation_image_to_pil,
parse_vqa_answer,
prompt_camera_choice,
)
PIL = pytest.importorskip("PIL")
from PIL import Image # noqa: E402
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def test_available_cameras_extracts_and_sorts_image_keys():
    """Only ``observation.images.*`` keys come back, in sorted order."""
    obs = {
        "observation.images.wrist": object(),
        "observation.state": object(),
        "observation.images.top": object(),
        "task": "x",
    }
    expected = ["observation.images.top", "observation.images.wrist"]
    assert available_cameras(obs) == expected
def test_available_cameras_handles_none_and_empty():
    """A missing or empty observation yields an empty camera list."""
    for empty_obs in (None, {}):
        assert available_cameras(empty_obs) == []
def test_camera_short_name_strips_prefix():
    """The ``observation.images.`` prefix is dropped; bare names pass through."""
    for key in ("observation.images.top", "top"):
        assert camera_short_name(key) == "top"
def test_prompt_camera_choice_single_camera_auto_selects():
    """With exactly one camera it is chosen without prompting the user."""
    only_camera = "observation.images.top"
    # _boom fails the test if input_fn is ever invoked.
    picked = prompt_camera_choice([only_camera], input_fn=_boom, print_fn=lambda *_: None)
    assert picked == only_camera
def test_prompt_camera_choice_by_number():
    """Replying with a 1-based index selects the matching camera."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(
        cameras,
        input_fn=lambda _: "2",
        print_fn=lambda *_: None,
    )
    assert picked == "observation.images.wrist"
def test_prompt_camera_choice_by_name():
    """Replying with a camera's short name selects that camera."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(
        cameras,
        input_fn=lambda _: "top",
        print_fn=lambda *_: None,
    )
    assert picked == "observation.images.top"
def test_prompt_camera_choice_invalid_returns_none():
    """Replies that match no camera (bad index or unknown name) yield ``None``."""
    cams = ["observation.images.top", "observation.images.wrist"]
    # "99" is past the end of the list. "0" sits just below the 1-based range:
    # it guards against an ``int(reply) - 1`` lookup wrapping to cams[-1].
    # "side" is a name that matches no camera.
    for reply in ("99", "0", "side"):
        chosen = prompt_camera_choice(
            cams,
            input_fn=lambda _, r=reply: r,  # bind reply now, not at call time
            print_fn=lambda *_: None,
        )
        assert chosen is None, f"reply {reply!r} should not select a camera, got {chosen!r}"
def _boom(*_args, **_kwargs):
raise AssertionError("input_fn should not be called")
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def test_parse_bbox_answer():
    """A ``detections`` payload with an xyxy box parses as a drawable bbox answer."""
    raw = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    result = parse_vqa_answer(raw)
    assert result["kind"] == "bbox"
    assert answer_has_overlay(result)
def test_parse_keypoint_answer():
    """A ``point`` payload in xy format parses as a drawable keypoint answer."""
    raw = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
    result = parse_vqa_answer(raw)
    assert result["kind"] == "keypoint"
    assert answer_has_overlay(result)
def test_parse_count_answer_is_not_an_overlay():
    """Count answers are recognised but have nothing to draw."""
    result = parse_vqa_answer('{"label": "cubes", "count": 2}')
    assert result["kind"] == "count"
    assert not answer_has_overlay(result)
def test_parse_invalid_json_returns_none():
    """Non-JSON text, and JSON that is not an object, both yield ``None``."""
    # The last entry is valid JSON, but an array is not a VQA answer object.
    for raw in ("not json at all", "", "[1, 2, 3]"):
        assert parse_vqa_answer(raw) is None
def test_parse_unknown_shape():
    """JSON objects of an unrecognised shape are tagged ``unknown`` and not drawn."""
    result = parse_vqa_answer('{"weird": "payload"}')
    assert result["kind"] == "unknown"
    assert not answer_has_overlay(result)
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def _blank(size=(160, 120)):
    """Return an all-black RGB canvas; ``size`` is Pillow's (width, height)."""
    black = (0, 0, 0)
    return Image.new("RGB", size, black)
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
    """Drawing a box returns a same-sized image whose pixels differ from the input."""
    canvas = _blank()
    box_answer = parse_vqa_answer(
        '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    )
    drawn = draw_vqa_overlay(canvas, box_answer)
    assert drawn.size == canvas.size
    assert drawn.tobytes() != canvas.tobytes()
def test_draw_keypoint_overlay_changes_pixels():
    """A keypoint marker changes the image without resizing it."""
    canvas = _blank()
    point_answer = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
    drawn = draw_vqa_overlay(canvas, point_answer)
    assert drawn.size == canvas.size
    assert drawn.tobytes() != canvas.tobytes()
def test_draw_overlay_non_spatial_leaves_image_unchanged():
    """Answers with no spatial content come back with identical pixels."""
    canvas = _blank()
    count_answer = parse_vqa_answer('{"label": "cubes", "count": 2}')
    drawn = draw_vqa_overlay(canvas, count_answer)
    assert drawn.tobytes() == canvas.tobytes()
def test_draw_overlay_tolerates_malformed_coordinates():
    """A bbox with the wrong number of coordinates must not raise."""
    canvas = _blank()
    # Two values where an xyxy box needs four.
    short_box = {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}}
    drawn = draw_vqa_overlay(canvas, short_box)
    assert drawn.size == canvas.size
def test_observation_image_to_pil_from_batched_float_array():
    """A (1, C, H, W) float frame in [0, 1] becomes an RGB image of size (W, H).

    Only the first (red) channel is lit, so the test also catches a wrong
    channel order or a missing [0, 1] -> [0, 255] scaling. An all-zero frame
    cannot detect either.
    """
    # (1, C, H, W) float array in [0, 1], the runtime observation shape.
    arr = np.zeros((1, 3, 24, 32), dtype=np.float32)
    arr[0, 0] = 1.0  # red channel fully on, green/blue off
    pil = observation_image_to_pil(arr)
    assert pil.size == (32, 24)
    assert pil.mode == "RGB"
    r, g, b = pil.getpixel((0, 0))
    assert r > 200 and g == 0 and b == 0
# ---------------------------------------------------------------------------
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
# ---------------------------------------------------------------------------
def test_parse_loc_keypoint_answer():
    """Two ``<loc>`` tokens (y first, then x) parse as a normalised keypoint."""
    # <locY><locX> label: y = 512/1023 ≈ 0.5, x = 256/1023 ≈ 0.25
    result = parse_vqa_answer("<loc0512><loc0256> blue cube")
    assert result["kind"] == "keypoint"
    assert result["normalized"] is True
    payload = result["payload"]
    px, py = payload["point"]
    assert 0.24 < px < 0.26
    assert 0.49 < py < 0.51
    assert payload["label"] == "blue cube"
    assert answer_has_overlay(result)
def test_parse_loc_bbox_answer():
    """Four ``<loc>`` tokens parse as a normalised xyxy box.

    The tokens come in PaliGemma's (y0, x0, y1, x1) order and the parsed box
    must be (x1, y1, x2, y2). An ordering check alone cannot catch a missed
    y/x swap: the swapped box (0.098, 0.078, 0.391, 0.352) is still ordered.
    So each coordinate is pinned.
    """
    # <locY0><locX0><locY1><locX1> label
    parsed = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
    assert parsed["kind"] == "bbox"
    assert parsed["normalized"] is True
    det = parsed["payload"]["detections"][0]
    x1, y1, x2, y2 = det["bbox"]
    assert x1 < x2 and y1 < y2
    # Tolerance accepts either /1023 or /1024 scaling but rejects a y/x mix-up.
    expected = [80 / 1023, 100 / 1023, 360 / 1023, 400 / 1023]
    assert [x1, y1, x2, y2] == pytest.approx(expected, abs=0.005)
    assert det["label"] == "yellow cube"
    assert answer_has_overlay(parsed)
def test_parse_loc_multiple_boxes():
    """Two ``;``-separated ``<loc>`` box groups yield two detections."""
    two_boxes = (
        "<loc0100><loc0080><loc0400><loc0360> cube ; "
        "<loc0200><loc0500><loc0600><loc0900> box"
    )
    result = parse_vqa_answer(two_boxes)
    assert result["kind"] == "bbox"
    detections = result["payload"]["detections"]
    assert len(detections) == 2
def test_parse_loc_takes_precedence_over_json():
    """``<loc>`` tokens win even when the surrounding text looks like JSON."""
    # An answer with <loc> tokens is parsed as loc even if JSON-ish.
    parsed = parse_vqa_answer('{"x": <loc0001><loc0002>}')
    # Check for None first, so a regression fails with a clear assertion
    # rather than a TypeError from subscripting None.
    assert parsed is not None
    assert parsed["normalized"] is True
    # Two <loc> tokens describe a single point.
    assert parsed["kind"] == "keypoint"
def test_draw_loc_overlay_denormalizes_to_pixels():
    """Normalised ``<loc>`` coordinates are scaled to the frame's pixel size.

    The frame is deliberately non-square (200 x 100). A missing
    denormalisation (marker near pixel 0) or swapped width/height (marker near
    (50, 100)) would put the drawing far from the centre (100, 50). The test
    therefore requires the changed pixels to reach the centre region.
    """
    img = _blank((200, 100))
    parsed = parse_vqa_answer("<loc0511><loc0511> cube")  # ~centre
    out = draw_vqa_overlay(img, parsed)
    assert out.size == img.size
    assert out.tobytes() != img.tobytes()
    # Rows/columns of every pixel the overlay touched: np.asarray gives (H, W, 3).
    changed = np.any(np.asarray(out) != np.asarray(img), axis=-1)
    ys, xs = np.nonzero(changed)
    cx, cy, tol = 100, 50, 5
    # The touched region must overlap a small window around the frame centre.
    assert xs.min() <= cx + tol and xs.max() >= cx - tol
    assert ys.min() <= cy + tol and ys.max() >= cy - tol