mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
c026aed8f8
Spatial VQA answers (bbox / keypoint) were trained as pixel-coordinate JSON, which fights PaliGemma's detection prior and leaks <loc>-token salad at inference. Convert them to PaliGemma's native <locNNNN> vocabulary instead so the LM head reuses that prior. Training side (text_processor_pi052.py): a target turn whose content parses as a bbox/keypoint answer is rewritten to <loc> text, using the camera frame's native (H, W) from the observation and the preceding image block. Non-spatial answers, subtask/memory targets and SmolVLA2 keep their JSON form — the dataset stays backbone-agnostic. Runtime side (smolvla2/inference/vqa.py): parse_vqa_answer detects <loc> answers (2 locs -> keypoint, 4 -> bbox), returning normalized [0,1] coords with a normalized flag; draw_vqa_overlay denormalizes against the chosen camera frame's pixel size. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
229 lines
7.6 KiB
Python
229 lines
7.6 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
|
|
|
|
Covers camera selection, VQA-answer parsing, and the bounding-box /
|
|
keypoint overlay drawing — the pure functions, no model load.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from lerobot.policies.smolvla2.inference.vqa import (
|
|
answer_has_overlay,
|
|
available_cameras,
|
|
camera_short_name,
|
|
draw_vqa_overlay,
|
|
observation_image_to_pil,
|
|
parse_vqa_answer,
|
|
prompt_camera_choice,
|
|
)
|
|
|
|
PIL = pytest.importorskip("PIL")
|
|
from PIL import Image # noqa: E402
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Camera selection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_available_cameras_extracts_and_sorts_image_keys():
    """Only ``observation.images.*`` keys survive, returned in sorted order."""
    wrist_key = "observation.images.wrist"
    top_key = "observation.images.top"
    obs = {
        wrist_key: object(),
        "observation.state": object(),
        top_key: object(),
        "task": "x",
    }
    # The dict lists wrist before top, so matching [top, wrist] proves sorting.
    expected = [top_key, wrist_key]
    assert available_cameras(obs) == expected
|
|
|
|
|
|
def test_available_cameras_handles_none_and_empty():
    """A missing or empty observation yields an empty camera list."""
    for empty_obs in (None, {}):
        assert available_cameras(empty_obs) == [], empty_obs
|
|
|
|
|
|
def test_camera_short_name_strips_prefix():
    """The ``observation.images.`` prefix is dropped; bare names pass through."""
    cases = (
        ("observation.images.top", "top"),
        ("top", "top"),
    )
    for full_key, short in cases:
        assert camera_short_name(full_key) == short
|
|
|
|
|
|
def test_prompt_camera_choice_single_camera_auto_selects():
    """With exactly one camera the choice is made without prompting the user."""
    only_cam = "observation.images.top"
    # ``_boom`` raises when invoked, so this fails if input_fn is ever called.
    result = prompt_camera_choice([only_cam], input_fn=_boom, print_fn=lambda *_: None)
    assert result == only_cam
|
|
|
|
|
|
def test_prompt_camera_choice_by_number():
    """A typed index selects the camera at that (1-based) position."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(
        cameras,
        input_fn=lambda _: "2",
        print_fn=lambda *_: None,
    )
    # "2" maps to the second list entry.
    assert picked == cameras[1]
|
|
|
|
|
|
def test_prompt_camera_choice_by_name():
    """Typing a camera's short name (without the key prefix) selects it."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(
        cameras,
        input_fn=lambda _: "top",
        print_fn=lambda *_: None,
    )
    assert picked == cameras[0]
|
|
|
|
|
|
def test_prompt_camera_choice_invalid_returns_none():
    """An out-of-range selection is reported as ``None``."""
    cameras = ["observation.images.top", "observation.images.wrist"]
    picked = prompt_camera_choice(cameras, input_fn=lambda _: "99", print_fn=lambda *_: None)
    assert picked is None
|
|
|
|
|
|
def _boom(*_args, **_kwargs):
|
|
raise AssertionError("input_fn should not be called")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Answer parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_parse_bbox_answer():
    """A JSON ``detections`` payload is classified as a drawable bbox answer."""
    raw = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    result = parse_vqa_answer(raw)
    assert result["kind"] == "bbox"
    assert answer_has_overlay(result)
|
|
|
|
|
|
def test_parse_keypoint_answer():
    """A JSON ``point`` payload is classified as a drawable keypoint answer."""
    result = parse_vqa_answer('{"label": "blue cube", "point_format": "xy", "point": [120, 90]}')
    assert result["kind"] == "keypoint"
    assert answer_has_overlay(result)
|
|
|
|
|
|
def test_parse_count_answer_is_not_an_overlay():
    """A ``count`` answer is recognised but has nothing to draw."""
    result = parse_vqa_answer('{"label": "cubes", "count": 2}')
    assert result["kind"] == "count"
    assert answer_has_overlay(result) is not True and not answer_has_overlay(result)
|
|
|
|
|
|
def test_parse_invalid_json_returns_none():
    """Non-JSON text, empty text and non-object JSON all parse to ``None``."""
    # "[1, 2, 3]" is valid JSON, but an array is not a VQA answer object.
    for bad in ("not json at all", "", "[1, 2, 3]"):
        assert parse_vqa_answer(bad) is None, bad
|
|
|
|
|
|
def test_parse_unknown_shape():
    """A JSON object of an unrecognised shape is tagged ``unknown``, with no overlay."""
    result = parse_vqa_answer('{"weird": "payload"}')
    assert result["kind"] == "unknown"
    assert not answer_has_overlay(result)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Overlay drawing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _blank(size=(160, 120)):
    """Return an all-black RGB PIL image; ``size`` is PIL's (width, height)."""
    black = (0, 0, 0)
    return Image.new("RGB", size, black)
|
|
|
|
|
|
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
    """Drawing a bbox marks the image without resizing it."""
    canvas = _blank()
    answer = parse_vqa_answer(
        '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    )
    annotated = draw_vqa_overlay(canvas, answer)
    assert annotated.size == canvas.size
    # At least one pixel must differ from the all-black input.
    assert annotated.tobytes() != canvas.tobytes()
|
|
|
|
|
|
def test_draw_keypoint_overlay_changes_pixels():
    """Drawing a keypoint marks the image without resizing it."""
    canvas = _blank()
    answer = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
    annotated = draw_vqa_overlay(canvas, answer)
    assert annotated.size == canvas.size
    assert annotated.tobytes() != canvas.tobytes()
|
|
|
|
|
|
def test_draw_overlay_non_spatial_leaves_image_unchanged():
    """A non-spatial answer (here a count) must not alter any pixel."""
    canvas = _blank()
    annotated = draw_vqa_overlay(canvas, parse_vqa_answer('{"label": "cubes", "count": 2}'))
    assert annotated.tobytes() == canvas.tobytes()
|
|
|
|
|
|
def test_draw_overlay_tolerates_malformed_coordinates():
    """A bbox with the wrong number of coordinates is drawn without raising."""
    canvas = _blank()
    malformed = {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}}
    annotated = draw_vqa_overlay(canvas, malformed)
    assert annotated.size == canvas.size
|
|
|
|
|
|
def test_observation_image_to_pil_from_batched_float_array():
    """A batched channel-first float array converts to an RGB PIL image."""
    # The runtime observation shape: (1, C, H, W), float values in [0, 1].
    height, width = 24, 32
    batch = np.zeros((1, 3, height, width), dtype=np.float32)
    pil = observation_image_to_pil(batch)
    # PIL reports size as (width, height).
    assert pil.size == (width, height)
    assert pil.mode == "RGB"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_parse_loc_keypoint_answer():
    """Two ``<loc>`` tokens parse to a normalized (x, y) keypoint plus its label."""
    # Token order is <locY><locX>: y=512/1023≈0.5, x=256/1023≈0.25.
    result = parse_vqa_answer("<loc0512><loc0256> blue cube")
    assert result["kind"] == "keypoint"
    assert result["normalized"] is True
    payload = result["payload"]
    px, py = payload["point"]
    assert 0.24 < px < 0.26
    assert 0.49 < py < 0.51
    assert payload["label"] == "blue cube"
    assert answer_has_overlay(result)
|
|
|
|
|
|
def test_parse_loc_bbox_answer():
    """Four ``<loc>`` tokens parse to one normalized xyxy bbox plus its label.

    Tokens arrive as ``<locY0><locX0><locY1><locX1>`` and are read back as
    ``x1, y1, x2, y2``. The coordinates are pinned here because an ordering
    check alone (``x1 < x2 and y1 < y2``) also passes for this input if the
    parser forgets to swap y and x.
    """
    parsed = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
    assert parsed["kind"] == "bbox"
    assert parsed["normalized"] is True
    det = parsed["payload"]["detections"][0]
    x1, y1, x2, y2 = det["bbox"]
    assert x1 < x2 and y1 < y2
    # The tolerance accepts /1023, /1024 or bin-centre normalization, but is
    # far tighter than the ~0.02 error a y/x swap would introduce.
    assert x1 == pytest.approx(80 / 1023, abs=2e-3)
    assert y1 == pytest.approx(100 / 1023, abs=2e-3)
    assert x2 == pytest.approx(360 / 1023, abs=2e-3)
    assert y2 == pytest.approx(400 / 1023, abs=2e-3)
    assert det["label"] == "yellow cube"
    assert answer_has_overlay(parsed)
|
|
|
|
|
|
def test_parse_loc_multiple_boxes():
    """``;``-separated loc groups yield one detection per object."""
    first = "<loc0100><loc0080><loc0400><loc0360> cube"
    second = "<loc0200><loc0500><loc0600><loc0900> box"
    parsed = parse_vqa_answer(f"{first} ; {second}")
    assert parsed["kind"] == "bbox"
    assert len(parsed["payload"]["detections"]) == 2
|
|
|
|
|
|
def test_parse_loc_takes_precedence_over_json():
    """Loc tokens are detected even when the answer looks like a JSON object.

    NOTE(review): the literal below is not valid JSON (the loc tokens are
    unquoted). It therefore checks loc detection inside JSON-looking text,
    not precedence over a JSON object that actually parses.
    """
    jsonish = '{"x": <loc0001><loc0002>}'
    parsed = parse_vqa_answer(jsonish)
    assert parsed["normalized"] is True
|
|
|
|
|
|
def test_draw_loc_overlay_denormalizes_to_pixels():
    """Normalized loc coordinates are scaled to the frame's pixel size.

    A centred keypoint on a non-square 200x100 frame must be drawn near pixel
    (100, 50). If denormalization were skipped, the marker would land near the
    top-left corner, and a bare "some pixel changed" check would still pass.
    """
    img = _blank((200, 100))
    parsed = parse_vqa_answer("<loc0511><loc0511> cube")  # ~centre
    out = draw_vqa_overlay(img, parsed)
    assert out.size == img.size
    assert out.tobytes() != img.tobytes()

    # Find every pixel that changed; convert to RGB so the channel counts match.
    changed = np.any(np.asarray(out.convert("RGB")) != np.asarray(img), axis=-1)
    rows, cols = np.nonzero(changed)
    cx, cy = img.width // 2, img.height // 2
    tol = 5
    # The bounding box of the changed region must straddle the frame centre.
    assert cols.min() <= cx + tol and cols.max() >= cx - tol
    assert rows.min() <= cy + tol and rows.max() >= cy - tol
|