feat(pi052): train VQA spatial answers in PaliGemma <loc> format

Spatial VQA answers (bbox / keypoint) were trained as pixel-coordinate
JSON, which fights PaliGemma's detection prior and leaks <loc>-token
salad at inference. Convert them to PaliGemma's native <locNNNN>
vocabulary instead so the LM head reuses that prior.

Training side (text_processor_pi052.py): a target turn whose content
parses as a bbox/keypoint answer is rewritten to <loc> text, using the
camera frame's native (H, W) from the observation and the preceding
image block. Non-spatial answers, subtask/memory targets and SmolVLA2
keep their JSON form — the dataset stays backbone-agnostic.

Runtime side (smolvla2/inference/vqa.py): parse_vqa_answer detects
<loc> answers (2 locs -> keypoint, 4 -> bbox), returning normalized
[0,1] coords with a normalized flag; draw_vqa_overlay denormalizes
against the chosen camera frame's pixel size.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-19 20:23:46 +02:00
parent e425dfd624
commit c026aed8f8
5 changed files with 1917 additions and 141 deletions
@@ -36,6 +36,7 @@ Outputs:
from __future__ import annotations
import json
import logging
import os
from dataclasses import dataclass
@@ -234,6 +235,134 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
return [int(value)] * batch_size
# ---------------------------------------------------------------------------
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
#
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
# answers are stored as JSON with *pixel* coordinates. Training those in
# ``<loc>`` form leverages PaliGemma's prior instead of fighting it (the
# ``<loc>``-token salad). The conversion lives here — not in the dataset
# — so the dataset stays backbone-agnostic (SmolVLA2 keeps the JSON).
# ---------------------------------------------------------------------------
def _camera_image_shapes(observation: dict[str, Any]) -> dict[str, tuple[int, int]]:
"""Map each ``observation.images.*`` key to its native ``(height, width)``.
VQA pixel coordinates are relative to the camera frame's native
resolution. PI052's input pipeline applies no spatial resize before
this step, so the observation image tensors are still at that
resolution — the correct reference for normalizing to ``<loc>``.
"""
shapes: dict[str, tuple[int, int]] = {}
for key, value in (observation or {}).items():
if not (isinstance(key, str) and key.startswith("observation.images.")):
continue
shape = getattr(value, "shape", None)
if shape is None or len(shape) < 2:
continue
shapes[key] = (int(shape[-2]), int(shape[-1])) # (H, W); handles (B,C,H,W)/(C,H,W)
return shapes
def _loc_token(coord: float, dim: int) -> str:
"""PaliGemma ``<locNNNN>`` for pixel ``coord`` on an axis of size ``dim``."""
idx = round(float(coord) / dim * 1023) if dim > 0 else 0
return f"<loc{max(0, min(1023, idx)):04d}>"
def _vqa_answer_to_loc(answer: dict[str, Any], height: int, width: int) -> str | None:
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
PaliGemma convention: a point is ``<locY><locX> label``; a box is
``<locY0><locX0><locY1><locX1> label`` (y before x, each index in
[0, 1023]). Returns ``None`` for non-spatial answers (count /
attribute / spatial-relation) — those keep their JSON form.
"""
point = answer.get("point")
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
try:
x, y = float(point[0]), float(point[1])
except (TypeError, ValueError):
return None
label = str(answer.get("label", "")).strip()
return f"{_loc_token(y, height)}{_loc_token(x, width)} {label}".strip()
detections = answer.get("detections")
if isinstance(detections, list) and detections:
parts: list[str] = []
for det in detections:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
label = str(det.get("label", "")).strip()
toks = (
f"{_loc_token(y1, height)}{_loc_token(x1, width)}"
f"{_loc_token(y2, height)}{_loc_token(x2, width)}"
)
parts.append(f"{toks} {label}".strip())
return " ; ".join(parts) if parts else None
return None
def _preceding_image_feature(messages: list[dict[str, Any]], idx: int) -> str | None:
"""Camera ``feature`` of the nearest image block at or before ``idx``."""
for j in range(min(idx, len(messages) - 1), -1, -1):
content = messages[j].get("content")
if not isinstance(content, list):
continue
for block in content:
if isinstance(block, dict) and block.get("type") == "image":
feature = block.get("feature")
if isinstance(feature, str):
return feature
return None
def _messages_vqa_to_loc(
messages: list[dict[str, Any]],
target_indices: list[int],
image_shapes: dict[str, tuple[int, int]] | None,
) -> list[dict[str, Any]]:
"""Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.
Each target turn whose content parses as a spatial VQA answer is
converted, using the camera frame found from the preceding image
block. Non-spatial answers, subtask / memory targets (plain text →
not JSON), and turns with no matching image shape are left untouched.
"""
if not image_shapes or not target_indices:
return messages
out = list(messages)
for idx in target_indices:
if not (0 <= idx < len(out)):
continue
content = out[idx].get("content")
if not isinstance(content, str) or not content.strip():
continue
try:
answer = json.loads(content)
except (ValueError, TypeError):
continue # subtask / memory targets are plain text — skip
if not isinstance(answer, dict):
continue
feature = _preceding_image_feature(out, idx)
if feature is None or feature not in image_shapes:
continue
h, w = image_shapes[feature]
loc_text = _vqa_answer_to_loc(answer, h, w)
if loc_text is not None:
out[idx] = {**out[idx], "content": loc_text}
return out
def _format_messages(
messages: list[dict[str, Any]],
target_indices: list[int] | None = None,
@@ -329,6 +458,9 @@ class PI052TextTokenizerStep(ProcessorStep):
return transition
tokenizer = self._ensure_tokenizer()
# Native camera resolutions — the reference frame for converting
# VQA pixel coordinates to PaliGemma <loc> tokens.
image_shapes = _camera_image_shapes(transition.get(TransitionKey.OBSERVATION) or {})
if _is_batched_messages(messages):
indices_iter = _sample_indices(complementary.get("index"), len(messages))
encoded = [
@@ -339,6 +471,7 @@ class PI052TextTokenizerStep(ProcessorStep):
list(tgt_indices),
complementary,
sample_idx=int(s_idx) if s_idx is not None else None,
image_shapes=image_shapes,
)
for msg, streams, tgt_indices, s_idx in zip(
messages,
@@ -358,6 +491,7 @@ class PI052TextTokenizerStep(ProcessorStep):
list(complementary.get("target_message_indices") or []),
complementary,
sample_idx=sample_idx,
image_shapes=image_shapes,
)
]
@@ -411,6 +545,7 @@ class PI052TextTokenizerStep(ProcessorStep):
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
image_shapes: dict[str, tuple[int, int]] | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
# Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping
@@ -428,6 +563,11 @@ class PI052TextTokenizerStep(ProcessorStep):
sample_idx=sample_idx,
)
# Rewrite bbox / keypoint VQA target answers from JSON to
# PaliGemma <loc> text — done before stripping so the image
# block (camera frame) is still available to normalize against.
messages = _messages_vqa_to_loc(messages, target_indices, image_shapes)
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
# stripping, so the spoken reply is actually tokenized and
# supervised (PaliGemma's flat prompt has no structured calls).
+77 -4
View File
@@ -37,6 +37,7 @@ from __future__ import annotations
import json
import logging
import os
import re
import subprocess
import sys
import time
@@ -50,6 +51,14 @@ logger = logging.getLogger(__name__)
_IMAGE_PREFIX = "observation.images."
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
# normalized to the image axis) instead of pixel-coordinate JSON, so the
# answer string the runtime parses can be e.g.
# ``<loc0512><loc0301> blue cube`` (point) or
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
# Iteration order for shape matching — most specific keys first so an
# answer is classified deterministically.
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
@@ -115,16 +124,74 @@ def prompt_camera_choice(
# ---------------------------------------------------------------------------
def _loc_to_norm(idx: int) -> float:
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
return max(0.0, min(1023.0, float(idx))) / 1023.0
def parse_loc_answer(answer: str) -> dict | None:
"""Parse a PaliGemma ``<loc>``-format spatial VQA answer.
PI052 trains spatial answers in PaliGemma's native detection
vocabulary: a point is ``<locY><locX> label``, a box is
``<locY0><locX0><locY1><locX1> label``, and multiple boxes are joined
by `` ; ``. Coordinates come back *normalized* ([0, 1]); the overlay
denormalizes them against the chosen camera frame's pixel size.
Returns ``{"kind", "payload", "normalized": True}`` on success
(``payload`` mirrors the JSON shapes so the overlay code is shared),
or ``None`` when the answer carries no ``<loc>`` tokens.
"""
if not answer or "<loc" not in answer:
return None
segments = [seg for seg in answer.split(";") if "<loc" in seg]
points: list[tuple[float, float, str]] = []
boxes: list[tuple[float, float, float, float, str]] = []
for seg in segments:
locs = [int(m) for m in _LOC_RE.findall(seg)]
label = _LOC_RE.sub("", seg).strip()
if len(locs) == 2:
y, x = (_loc_to_norm(v) for v in locs[:2])
points.append((x, y, label))
elif len(locs) >= 4:
y1, x1, y2, x2 = (_loc_to_norm(v) for v in locs[:4])
boxes.append((x1, y1, x2, y2, label))
if boxes:
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
for (x1, y1, x2, y2, lbl) in boxes
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
if len(points) == 1:
x, y, lbl = points[0]
return {
"kind": "keypoint",
"payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
"normalized": True,
}
if points: # several bare points → treat as detections-as-points
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
return None
def parse_vqa_answer(answer: str) -> dict | None:
"""Parse a VQA answer string into ``{"kind", "payload"}``.
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
when the JSON doesn't match any known shape. Returns ``None`` when
the answer is not parseable JSON / not a JSON object.
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
spatial answers are detected first (PI052 trains them in that native
format). Returns ``None`` when the answer is neither ``<loc>`` text
nor a parseable JSON object.
"""
if not answer or not answer.strip():
return None
loc_parsed = parse_loc_answer(answer)
if loc_parsed is not None:
return loc_parsed
try:
payload = json.loads(answer)
except (ValueError, TypeError):
@@ -189,7 +256,9 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
``unknown``) are returned as an unmodified copy.
``unknown``) are returned as an unmodified copy. When ``parsed`` has
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
coordinates are scaled to the image's pixel size.
"""
from PIL import ImageDraw # noqa: PLC0415
@@ -197,6 +266,8 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
kind = parsed.get("kind")
payload = parsed.get("payload") or {}
draw = ImageDraw.Draw(img)
w, h = img.size
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
if kind == "bbox":
for det in payload.get("detections") or []:
@@ -209,6 +280,8 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
x1, x2 = x1 * sx, x2 * sx
y1, y2 = y1 * sy, y2 * sy
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
label = str(det.get("label", "")).strip()
if label:
@@ -217,7 +290,7 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
point = payload.get("point")
if isinstance(point, list | tuple) and len(point) == 2:
try:
x, y = float(point[0]), float(point[1])
x, y = float(point[0]) * sx, float(point[1]) * sy
except (TypeError, ValueError):
return img
r = 6
+123
View File
@@ -0,0 +1,123 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
reuses the detection prior instead of fighting it (the ``<loc>``-salad
bug). The dataset stays backbone-agnostic JSON; the conversion lives in
PI052's tokenizer. These tests pin the JSON → ``<loc>`` rewrite.
"""
import pytest
pytest.importorskip("transformers")
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
_camera_image_shapes,
_loc_token,
_messages_vqa_to_loc,
_vqa_answer_to_loc,
)
class _FakeTensor:
def __init__(self, shape):
self.shape = shape
def test_camera_image_shapes_extracts_hw_from_image_keys():
obs = {
"observation.images.top": _FakeTensor((1, 3, 240, 320)),
"observation.images.wrist": _FakeTensor((3, 480, 640)),
"observation.state": _FakeTensor((1, 7)),
"task": "x",
}
assert _camera_image_shapes(obs) == {
"observation.images.top": (240, 320),
"observation.images.wrist": (480, 640),
}
def test_camera_image_shapes_handles_empty():
assert _camera_image_shapes({}) == {}
assert _camera_image_shapes(None) == {}
def test_loc_token_normalizes_and_clamps():
assert _loc_token(0, 100) == "<loc0000>"
assert _loc_token(100, 100) == "<loc1023>"
assert _loc_token(50, 100) == f"<loc{round(50 / 100 * 1023):04d}>"
# out-of-range coordinates clamp into [0, 1023]
assert _loc_token(999, 100) == "<loc1023>"
assert _loc_token(-5, 100) == "<loc0000>"
def test_vqa_answer_to_loc_keypoint():
answer = {"label": "blue cube", "point_format": "xy", "point": [160, 120]}
# height=240, width=320 → y=120/240=0.5, x=160/320=0.5
out = _vqa_answer_to_loc(answer, height=240, width=320)
assert out == "<loc0512><loc0512> blue cube"
def test_vqa_answer_to_loc_bbox():
answer = {
"detections": [
{"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 320, 240]},
]
}
out = _vqa_answer_to_loc(answer, height=240, width=320)
assert out == "<loc0000><loc0000><loc1023><loc1023> cube"
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
assert _vqa_answer_to_loc({"label": "cubes", "count": 2}, 240, 320) is None
assert _vqa_answer_to_loc({"weird": "payload"}, 240, 320) is None
def test_messages_vqa_to_loc_rewrites_target_turn():
messages = [
{
"role": "user",
"content": [
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "where is the cube?"},
],
},
{"role": "assistant", "content": '{"label": "cube", "point_format": "xy", "point": [160, 120]}'},
]
shapes = {"observation.images.top": (240, 320)}
out = _messages_vqa_to_loc(messages, target_indices=[1], image_shapes=shapes)
assert out[1]["content"] == "<loc0512><loc0512> cube"
# input messages are not mutated
assert messages[1]["content"].startswith("{")
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
messages = [
{"role": "user", "content": [{"type": "image", "feature": "observation.images.top"}]},
{"role": "assistant", "content": "pick up the cube"},
]
shapes = {"observation.images.top": (240, 320)}
out = _messages_vqa_to_loc(messages, target_indices=[1], image_shapes=shapes)
assert out[1]["content"] == "pick up the cube"
def test_messages_vqa_to_loc_noop_without_shapes():
messages = [{"role": "assistant", "content": '{"label": "c", "point_format": "xy", "point": [1, 2]}'}]
assert _messages_vqa_to_loc(messages, [0], None) is messages
assert _messages_vqa_to_loc(messages, [0], {}) is messages
@@ -177,3 +177,52 @@ def test_observation_image_to_pil_from_batched_float_array():
pil = observation_image_to_pil(arr)
assert pil.size == (32, 24)
assert pil.mode == "RGB"
# ---------------------------------------------------------------------------
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
# ---------------------------------------------------------------------------
def test_parse_loc_keypoint_answer():
# <locY><locX> label — y=512/1023≈0.5, x=256/1023≈0.25
parsed = parse_vqa_answer("<loc0512><loc0256> blue cube")
assert parsed["kind"] == "keypoint"
assert parsed["normalized"] is True
x, y = parsed["payload"]["point"]
assert 0.24 < x < 0.26
assert 0.49 < y < 0.51
assert parsed["payload"]["label"] == "blue cube"
assert answer_has_overlay(parsed)
def test_parse_loc_bbox_answer():
# <locY0><locX0><locY1><locX1> label
parsed = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
assert parsed["kind"] == "bbox"
assert parsed["normalized"] is True
det = parsed["payload"]["detections"][0]
x1, y1, x2, y2 = det["bbox"]
assert x1 < x2 and y1 < y2
assert det["label"] == "yellow cube"
assert answer_has_overlay(parsed)
def test_parse_loc_multiple_boxes():
answer = "<loc0100><loc0080><loc0400><loc0360> cube ; <loc0200><loc0500><loc0600><loc0900> box"
parsed = parse_vqa_answer(answer)
assert parsed["kind"] == "bbox"
assert len(parsed["payload"]["detections"]) == 2
def test_parse_loc_takes_precedence_over_json():
# An answer with <loc> tokens is parsed as loc even if JSON-ish.
assert parse_vqa_answer('{"x": <loc0001><loc0002>}')["normalized"] is True
def test_draw_loc_overlay_denormalizes_to_pixels():
img = _blank((200, 100))
parsed = parse_vqa_answer("<loc0511><loc0511> cube") # ~centre
out = draw_vqa_overlay(img, parsed)
assert out.size == img.size
assert out.tobytes() != img.tobytes()
Generated
+1528 -137
View File
File diff suppressed because it is too large Load Diff