mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
feat(pi052): train VQA spatial answers in PaliGemma <loc> format
Spatial VQA answers (bbox / keypoint) were trained as pixel-coordinate JSON, which fights PaliGemma's detection prior and leaks <loc>-token salad at inference. Convert them to PaliGemma's native <locNNNN> vocabulary instead so the LM head reuses that prior. Training side (text_processor_pi052.py): a target turn whose content parses as a bbox/keypoint answer is rewritten to <loc> text, using the camera frame's native (H, W) from the observation and the preceding image block. Non-spatial answers, subtask/memory targets and SmolVLA2 keep their JSON form — the dataset stays backbone-agnostic. Runtime side (smolvla2/inference/vqa.py): parse_vqa_answer detects <loc> answers (2 locs -> keypoint, 4 -> bbox), returning normalized [0,1] coords with a normalized flag; draw_vqa_overlay denormalizes against the chosen camera frame's pixel size. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -36,6 +36,7 @@ Outputs:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
@@ -234,6 +235,134 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
return [int(value)] * batch_size
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
|
||||
#
|
||||
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
|
||||
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
|
||||
# answers are stored as JSON with *pixel* coordinates. Training those in
|
||||
# ``<loc>`` form leverages PaliGemma's prior instead of fighting it (the
|
||||
# ``<loc>``-token salad). The conversion lives here — not in the dataset
|
||||
# — so the dataset stays backbone-agnostic (SmolVLA2 keeps the JSON).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _camera_image_shapes(observation: dict[str, Any]) -> dict[str, tuple[int, int]]:
|
||||
"""Map each ``observation.images.*`` key to its native ``(height, width)``.
|
||||
|
||||
VQA pixel coordinates are relative to the camera frame's native
|
||||
resolution. PI052's input pipeline applies no spatial resize before
|
||||
this step, so the observation image tensors are still at that
|
||||
resolution — the correct reference for normalizing to ``<loc>``.
|
||||
"""
|
||||
shapes: dict[str, tuple[int, int]] = {}
|
||||
for key, value in (observation or {}).items():
|
||||
if not (isinstance(key, str) and key.startswith("observation.images.")):
|
||||
continue
|
||||
shape = getattr(value, "shape", None)
|
||||
if shape is None or len(shape) < 2:
|
||||
continue
|
||||
shapes[key] = (int(shape[-2]), int(shape[-1])) # (H, W); handles (B,C,H,W)/(C,H,W)
|
||||
return shapes
|
||||
|
||||
|
||||
def _loc_token(coord: float, dim: int) -> str:
|
||||
"""PaliGemma ``<locNNNN>`` for pixel ``coord`` on an axis of size ``dim``."""
|
||||
idx = round(float(coord) / dim * 1023) if dim > 0 else 0
|
||||
return f"<loc{max(0, min(1023, idx)):04d}>"
|
||||
|
||||
|
||||
def _vqa_answer_to_loc(answer: dict[str, Any], height: int, width: int) -> str | None:
    """Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.

    PaliGemma convention: a point is ``<locY><locX> label``; a box is
    ``<locY0><locX0><locY1><locX1> label`` (y before x, each index in
    [0, 1023]). Multiple boxes are joined with `` ; ``.

    Args:
        answer: Parsed JSON answer. A keypoint answer has a 2-element
            ``point`` (read as ``[x, y]``) together with a ``point_format``
            key; a detection answer has ``detections``, a list of dicts
            each holding a 4-element ``bbox`` (read as ``[x1, y1, x2, y2]``).
        height: Camera frame height in pixels (normalizes y).
        width: Camera frame width in pixels (normalizes x).

    Returns:
        The ``<loc>`` text, or ``None`` for non-spatial answers (count /
        attribute / spatial-relation) and for answers whose coordinates are
        missing, non-numeric or non-finite — those keep their JSON form.
    """
    import math  # noqa: PLC0415

    def _finite_floats(values: Any) -> list[float] | None:
        """Return ``values`` as floats, or ``None`` if any is unusable."""
        try:
            nums = [float(v) for v in values]
        except (TypeError, ValueError, OverflowError):
            return None
        # ``json.loads`` accepts NaN / Infinity; such values cannot be mapped
        # to a ``<loc>`` index, so treat them like any other malformed entry.
        return nums if all(math.isfinite(n) for n in nums) else None

    point = answer.get("point")
    if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
        xy = _finite_floats(point)
        if xy is None:
            return None
        x, y = xy
        label = str(answer.get("label", "")).strip()
        # Outer strip drops the trailing space when the label is empty.
        return f"{_loc_token(y, height)}{_loc_token(x, width)} {label}".strip()

    detections = answer.get("detections")
    if isinstance(detections, list) and detections:
        parts: list[str] = []
        for det in detections:
            if not isinstance(det, dict):
                continue
            box = det.get("bbox")
            if not (isinstance(box, list | tuple) and len(box) == 4):
                continue
            corners = _finite_floats(box)
            if corners is None:
                continue  # skip the malformed box, keep the others
            # NOTE(review): ``bbox_format`` is not inspected here; the four
            # values are assumed to be xyxy — confirm against the recipe.
            x1, y1, x2, y2 = corners
            label = str(det.get("label", "")).strip()
            toks = (
                f"{_loc_token(y1, height)}{_loc_token(x1, width)}"
                f"{_loc_token(y2, height)}{_loc_token(x2, width)}"
            )
            parts.append(f"{toks} {label}".strip())
        return " ; ".join(parts) if parts else None
    return None
|
||||
|
||||
|
||||
def _preceding_image_feature(messages: list[dict[str, Any]], idx: int) -> str | None:
|
||||
"""Camera ``feature`` of the nearest image block at or before ``idx``."""
|
||||
for j in range(min(idx, len(messages) - 1), -1, -1):
|
||||
content = messages[j].get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
feature = block.get("feature")
|
||||
if isinstance(feature, str):
|
||||
return feature
|
||||
return None
|
||||
|
||||
|
||||
def _messages_vqa_to_loc(
    messages: list[dict[str, Any]],
    target_indices: list[int],
    image_shapes: dict[str, tuple[int, int]] | None,
) -> list[dict[str, Any]]:
    """Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.

    For each target index, the turn's string content is parsed as JSON; if
    it is a spatial VQA answer it is converted using the ``(H, W)`` of the
    camera named by the preceding image block. Everything else is left as
    is: non-spatial answers, plain-text targets (subtask / memory, which
    are not JSON), out-of-range indices, and turns whose camera has no
    entry in ``image_shapes``.

    The input list is returned unchanged (same object) when there are no
    shapes or no targets; otherwise a shallow copy is returned and rewritten
    turns are replaced by new dicts, so the input messages are not mutated.
    """
    if not (image_shapes and target_indices):
        return messages

    rewritten = list(messages)
    total = len(rewritten)
    for pos in target_indices:
        if pos < 0 or pos >= total:
            continue
        turn = rewritten[pos]
        raw = turn.get("content")
        if not (isinstance(raw, str) and raw.strip()):
            continue
        try:
            parsed = json.loads(raw)
        except (ValueError, TypeError):
            # Subtask / memory targets are plain text, not JSON.
            continue
        if not isinstance(parsed, dict):
            continue
        camera = _preceding_image_feature(rewritten, pos)
        if camera is None or camera not in image_shapes:
            continue
        frame_h, frame_w = image_shapes[camera]
        loc_text = _vqa_answer_to_loc(parsed, frame_h, frame_w)
        if loc_text is None:
            continue
        rewritten[pos] = {**turn, "content": loc_text}
    return rewritten
|
||||
|
||||
|
||||
def _format_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int] | None = None,
|
||||
@@ -329,6 +458,9 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
return transition
|
||||
|
||||
tokenizer = self._ensure_tokenizer()
|
||||
# Native camera resolutions — the reference frame for converting
|
||||
# VQA pixel coordinates to PaliGemma <loc> tokens.
|
||||
image_shapes = _camera_image_shapes(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
if _is_batched_messages(messages):
|
||||
indices_iter = _sample_indices(complementary.get("index"), len(messages))
|
||||
encoded = [
|
||||
@@ -339,6 +471,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
list(tgt_indices),
|
||||
complementary,
|
||||
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||
image_shapes=image_shapes,
|
||||
)
|
||||
for msg, streams, tgt_indices, s_idx in zip(
|
||||
messages,
|
||||
@@ -358,6 +491,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
list(complementary.get("target_message_indices") or []),
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
image_shapes=image_shapes,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -411,6 +545,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
image_shapes: dict[str, tuple[int, int]] | None = None,
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
|
||||
# Optional: drop non-target messages per the dropout config.
|
||||
# Keeps the supervised-target indices stable by re-mapping
|
||||
@@ -428,6 +563,11 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
|
||||
# Rewrite bbox / keypoint VQA target answers from JSON to
|
||||
# PaliGemma <loc> text — done before stripping so the image
|
||||
# block (camera frame) is still available to normalize against.
|
||||
messages = _messages_vqa_to_loc(messages, target_indices, image_shapes)
|
||||
|
||||
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
|
||||
# stripping, so the spoken reply is actually tokenized and
|
||||
# supervised (PaliGemma's flat prompt has no structured calls).
|
||||
|
||||
@@ -37,6 +37,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
@@ -50,6 +51,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_IMAGE_PREFIX = "observation.images."
|
||||
|
||||
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
|
||||
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
|
||||
# normalized to the image axis) instead of pixel-coordinate JSON, so the
|
||||
# answer string the runtime parses can be e.g.
|
||||
# ``<loc0512><loc0301> blue cube`` (point) or
|
||||
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
|
||||
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
|
||||
|
||||
# Iteration order for shape matching — most specific keys first so an
|
||||
# answer is classified deterministically.
|
||||
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
@@ -115,16 +124,74 @@ def prompt_camera_choice(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _loc_to_norm(idx: int) -> float:
|
||||
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
|
||||
return max(0.0, min(1023.0, float(idx))) / 1023.0
|
||||
|
||||
|
||||
def parse_loc_answer(answer: str) -> dict | None:
    """Parse a PaliGemma ``<loc>``-format spatial VQA answer.

    PI052 trains spatial answers in PaliGemma's native detection
    vocabulary: a point is ``<locY><locX> label``, a box is
    ``<locY0><locX0><locY1><locX1> label``, and multiple boxes are joined
    by `` ; ``. Coordinates come back *normalized* ([0, 1]); the overlay
    denormalizes them against the chosen camera frame's pixel size.

    Per ``;``-separated segment: exactly two ``<loc>`` tokens form a point,
    four or more form a box from the first four, and any other count is
    ignored. If at least one box was found only the boxes are returned.
    Box corners are stored as ``(min, max)`` per axis, because generated
    text can carry inverted corners and an ``xyxy`` box with ``x2 < x1``
    is rejected by the drawing code.

    Returns:
        ``{"kind", "payload", "normalized": True}`` on success (``payload``
        mirrors the JSON shapes so the overlay code is shared), or ``None``
        when the answer carries no usable ``<loc>`` tokens.
    """
    if not answer or "<loc" not in answer:
        return None

    points: list[tuple[float, float, str]] = []
    boxes: list[tuple[float, float, float, float, str]] = []
    for seg in answer.split(";"):
        if "<loc" not in seg:
            continue
        locs = [_loc_to_norm(int(m)) for m in _LOC_RE.findall(seg)]
        label = _LOC_RE.sub("", seg).strip()
        if len(locs) == 2:
            y, x = locs  # tokens are emitted y-first
            points.append((x, y, label))
        elif len(locs) >= 4:
            ya, xa, yb, xb = locs[:4]
            # Order the corners so the box is a valid xyxy rectangle.
            boxes.append((min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb), label))

    if boxes:
        detections = [
            {"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
            for (x1, y1, x2, y2, lbl) in boxes
        ]
        return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
    if len(points) == 1:
        x, y, lbl = points[0]
        return {
            "kind": "keypoint",
            "payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
            "normalized": True,
        }
    if points:  # several bare points → treat as detections-as-points
        detections = [
            {"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
        ]
        return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
    return None
|
||||
|
||||
|
||||
def parse_vqa_answer(answer: str) -> dict | None:
|
||||
"""Parse a VQA answer string into ``{"kind", "payload"}``.
|
||||
|
||||
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
|
||||
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
|
||||
when the JSON doesn't match any known shape. Returns ``None`` when
|
||||
the answer is not parseable JSON / not a JSON object.
|
||||
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
|
||||
spatial answers are detected first (PI052 trains them in that native
|
||||
format). Returns ``None`` when the answer is neither ``<loc>`` text
|
||||
nor a parseable JSON object.
|
||||
"""
|
||||
if not answer or not answer.strip():
|
||||
return None
|
||||
loc_parsed = parse_loc_answer(answer)
|
||||
if loc_parsed is not None:
|
||||
return loc_parsed
|
||||
try:
|
||||
payload = json.loads(answer)
|
||||
except (ValueError, TypeError):
|
||||
@@ -189,7 +256,9 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
|
||||
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
|
||||
|
||||
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
|
||||
``unknown``) are returned as an unmodified copy.
|
||||
``unknown``) are returned as an unmodified copy. When ``parsed`` has
|
||||
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
|
||||
coordinates are scaled to the image's pixel size.
|
||||
"""
|
||||
from PIL import ImageDraw # noqa: PLC0415
|
||||
|
||||
@@ -197,6 +266,8 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
|
||||
kind = parsed.get("kind")
|
||||
payload = parsed.get("payload") or {}
|
||||
draw = ImageDraw.Draw(img)
|
||||
w, h = img.size
|
||||
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
|
||||
|
||||
if kind == "bbox":
|
||||
for det in payload.get("detections") or []:
|
||||
@@ -209,6 +280,8 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
|
||||
x1, y1, x2, y2 = (float(v) for v in box)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
x1, x2 = x1 * sx, x2 * sx
|
||||
y1, y2 = y1 * sy, y2 * sy
|
||||
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
|
||||
label = str(det.get("label", "")).strip()
|
||||
if label:
|
||||
@@ -217,7 +290,7 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
|
||||
point = payload.get("point")
|
||||
if isinstance(point, list | tuple) and len(point) == 2:
|
||||
try:
|
||||
x, y = float(point[0]), float(point[1])
|
||||
x, y = float(point[0]) * sx, float(point[1]) * sy
|
||||
except (TypeError, ValueError):
|
||||
return img
|
||||
r = 6
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
|
||||
|
||||
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
|
||||
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
|
||||
reuses the detection prior instead of fighting it (the ``<loc>``-salad
|
||||
bug). The dataset stays backbone-agnostic JSON; the conversion lives in
|
||||
PI052's tokenizer. These tests pin the JSON → ``<loc>`` rewrite.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
|
||||
_camera_image_shapes,
|
||||
_loc_token,
|
||||
_messages_vqa_to_loc,
|
||||
_vqa_answer_to_loc,
|
||||
)
|
||||
|
||||
|
||||
class _FakeTensor:
|
||||
def __init__(self, shape):
|
||||
self.shape = shape
|
||||
|
||||
|
||||
def test_camera_image_shapes_extracts_hw_from_image_keys():
    """Only ``observation.images.*`` entries are reported, as ``(H, W)``."""
    batched_top = _FakeTensor((1, 3, 240, 320))
    unbatched_wrist = _FakeTensor((3, 480, 640))
    observation = {
        "observation.images.top": batched_top,
        "observation.images.wrist": unbatched_wrist,
        # Non-image entries must be ignored.
        "observation.state": _FakeTensor((1, 7)),
        "task": "x",
    }
    expected = {
        "observation.images.top": (240, 320),
        "observation.images.wrist": (480, 640),
    }
    assert _camera_image_shapes(observation) == expected
|
||||
|
||||
|
||||
def test_camera_image_shapes_handles_empty():
    """An empty or missing observation yields an empty mapping."""
    for empty_observation in ({}, None):
        assert _camera_image_shapes(empty_observation) == {}
|
||||
|
||||
|
||||
def test_loc_token_normalizes_and_clamps():
    """Pixel coordinates map onto [0, 1023]; out-of-range values clamp."""
    axis = 100
    in_range_cases = [
        (0, "<loc0000>"),
        (100, "<loc1023>"),
        (50, f"<loc{round(50 / 100 * 1023):04d}>"),
    ]
    for coord, token in in_range_cases:
        assert _loc_token(coord, axis) == token
    # out-of-range coordinates clamp into [0, 1023]
    clamped_cases = [(999, "<loc1023>"), (-5, "<loc0000>")]
    for coord, token in clamped_cases:
        assert _loc_token(coord, axis) == token
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_keypoint():
    """A keypoint answer becomes ``<locY><locX> label``."""
    keypoint = {"label": "blue cube", "point_format": "xy", "point": [160, 120]}
    # On a 240x320 frame the point sits at the centre: y = 120/240, x = 160/320.
    loc_text = _vqa_answer_to_loc(keypoint, height=240, width=320)
    assert loc_text == "<loc0512><loc0512> blue cube"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_bbox():
    """A full-frame box maps to the extreme ``<loc>`` indices, y before x."""
    full_frame = {"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 320, 240]}
    loc_text = _vqa_answer_to_loc({"detections": [full_frame]}, height=240, width=320)
    assert loc_text == "<loc0000><loc0000><loc1023><loc1023> cube"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
    """Answers without a point or detections are not converted."""
    non_spatial_answers = [
        {"label": "cubes", "count": 2},
        {"weird": "payload"},
    ]
    for payload in non_spatial_answers:
        assert _vqa_answer_to_loc(payload, 240, 320) is None
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_rewrites_target_turn():
    """The JSON target turn is rewritten using the preceding image's shape."""
    question_turn = {
        "role": "user",
        "content": [
            {"type": "image", "feature": "observation.images.top"},
            {"type": "text", "text": "where is the cube?"},
        ],
    }
    answer_turn = {
        "role": "assistant",
        "content": '{"label": "cube", "point_format": "xy", "point": [160, 120]}',
    }
    messages = [question_turn, answer_turn]
    frame_shapes = {"observation.images.top": (240, 320)}

    rewritten = _messages_vqa_to_loc(messages, target_indices=[1], image_shapes=frame_shapes)

    assert rewritten[1]["content"] == "<loc0512><loc0512> cube"
    # input messages are not mutated
    assert messages[1]["content"].startswith("{")
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
    """A non-JSON target (e.g. a subtask string) passes through unchanged."""
    image_turn = {"role": "user", "content": [{"type": "image", "feature": "observation.images.top"}]}
    subtask_turn = {"role": "assistant", "content": "pick up the cube"}
    frame_shapes = {"observation.images.top": (240, 320)}

    rewritten = _messages_vqa_to_loc(
        [image_turn, subtask_turn], target_indices=[1], image_shapes=frame_shapes
    )

    assert rewritten[1]["content"] == "pick up the cube"
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_noop_without_shapes():
    """With no camera shapes the very same list object is handed back."""
    messages = [{"role": "assistant", "content": '{"label": "c", "point_format": "xy", "point": [1, 2]}'}]
    for missing_shapes in (None, {}):
        assert _messages_vqa_to_loc(messages, [0], missing_shapes) is messages
|
||||
@@ -177,3 +177,52 @@ def test_observation_image_to_pil_from_batched_float_array():
|
||||
pil = observation_image_to_pil(arr)
|
||||
assert pil.size == (32, 24)
|
||||
assert pil.mode == "RGB"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_loc_keypoint_answer():
    """Two ``<loc>`` tokens (y first, then x) parse as a normalized keypoint."""
    # y = 512/1023 ≈ 0.5, x = 256/1023 ≈ 0.25
    result = parse_vqa_answer("<loc0512><loc0256> blue cube")
    assert result["kind"] == "keypoint"
    assert result["normalized"] is True
    payload = result["payload"]
    px, py = payload["point"]
    assert 0.24 < px < 0.26
    assert 0.49 < py < 0.51
    assert payload["label"] == "blue cube"
    assert answer_has_overlay(result)
|
||||
|
||||
|
||||
def test_parse_loc_bbox_answer():
    """Four ``<loc>`` tokens (y0, x0, y1, x1) parse as a normalized box."""
    result = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
    assert result["kind"] == "bbox"
    assert result["normalized"] is True
    first = result["payload"]["detections"][0]
    left, top, right, bottom = first["bbox"]
    assert left < right and top < bottom
    assert first["label"] == "yellow cube"
    assert answer_has_overlay(result)
|
||||
|
||||
|
||||
def test_parse_loc_multiple_boxes():
    """Boxes separated by `` ; `` are all returned as detections."""
    two_boxes = "<loc0100><loc0080><loc0400><loc0360> cube ; <loc0200><loc0500><loc0600><loc0900> box"
    result = parse_vqa_answer(two_boxes)
    assert result["kind"] == "bbox"
    detections = result["payload"]["detections"]
    assert len(detections) == 2
|
||||
|
||||
|
||||
def test_parse_loc_takes_precedence_over_json():
    """``<loc>`` tokens win even when the surrounding text looks like JSON."""
    jsonish = '{"x": <loc0001><loc0002>}'
    result = parse_vqa_answer(jsonish)
    assert result["normalized"] is True
|
||||
|
||||
|
||||
def test_draw_loc_overlay_denormalizes_to_pixels():
    """Drawing a normalized keypoint changes pixels but not the image size."""
    canvas = _blank((200, 100))
    centre_point = parse_vqa_answer("<loc0511><loc0511> cube")  # ~centre
    drawn = draw_vqa_overlay(canvas, centre_point)
    assert drawn.size == canvas.size
    # Something must have been painted onto the copy.
    assert drawn.tobytes() != canvas.tobytes()
|
||||
|
||||
Reference in New Issue
Block a user