mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
feat(pi052): train VQA spatial answers in PaliGemma <loc> format
Spatial VQA answers (bbox / keypoint) were trained as pixel-coordinate JSON, which fights PaliGemma's detection prior and leaks <loc>-token salad at inference. Convert them to PaliGemma's native <locNNNN> vocabulary instead so the LM head reuses that prior. Training side (text_processor_pi052.py): a target turn whose content parses as a bbox/keypoint answer is rewritten to <loc> text, using the camera frame's native (H, W) from the observation and the preceding image block. Non-spatial answers, subtask/memory targets and SmolVLA2 keep their JSON form — the dataset stays backbone-agnostic. Runtime side (smolvla2/inference/vqa.py): parse_vqa_answer detects <loc> answers (2 locs -> keypoint, 4 -> bbox), returning normalized [0,1] coords with a normalized flag; draw_vqa_overlay denormalizes against the chosen camera frame's pixel size. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
|
||||
|
||||
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
|
||||
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
|
||||
reuses the detection prior instead of fighting it (the ``<loc>``-salad
|
||||
bug). The dataset stays backbone-agnostic JSON; the conversion lives in
|
||||
PI052's tokenizer. These tests pin the JSON → ``<loc>`` rewrite.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
|
||||
_camera_image_shapes,
|
||||
_loc_token,
|
||||
_messages_vqa_to_loc,
|
||||
_vqa_answer_to_loc,
|
||||
)
|
||||
|
||||
|
||||
class _FakeTensor:
|
||||
def __init__(self, shape):
|
||||
self.shape = shape
|
||||
|
||||
|
||||
def test_camera_image_shapes_extracts_hw_from_image_keys():
|
||||
obs = {
|
||||
"observation.images.top": _FakeTensor((1, 3, 240, 320)),
|
||||
"observation.images.wrist": _FakeTensor((3, 480, 640)),
|
||||
"observation.state": _FakeTensor((1, 7)),
|
||||
"task": "x",
|
||||
}
|
||||
assert _camera_image_shapes(obs) == {
|
||||
"observation.images.top": (240, 320),
|
||||
"observation.images.wrist": (480, 640),
|
||||
}
|
||||
|
||||
|
||||
def test_camera_image_shapes_handles_empty():
|
||||
assert _camera_image_shapes({}) == {}
|
||||
assert _camera_image_shapes(None) == {}
|
||||
|
||||
|
||||
def test_loc_token_normalizes_and_clamps():
|
||||
assert _loc_token(0, 100) == "<loc0000>"
|
||||
assert _loc_token(100, 100) == "<loc1023>"
|
||||
assert _loc_token(50, 100) == f"<loc{round(50 / 100 * 1023):04d}>"
|
||||
# out-of-range coordinates clamp into [0, 1023]
|
||||
assert _loc_token(999, 100) == "<loc1023>"
|
||||
assert _loc_token(-5, 100) == "<loc0000>"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_keypoint():
|
||||
answer = {"label": "blue cube", "point_format": "xy", "point": [160, 120]}
|
||||
# height=240, width=320 → y=120/240=0.5, x=160/320=0.5
|
||||
out = _vqa_answer_to_loc(answer, height=240, width=320)
|
||||
assert out == "<loc0512><loc0512> blue cube"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_bbox():
|
||||
answer = {
|
||||
"detections": [
|
||||
{"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 320, 240]},
|
||||
]
|
||||
}
|
||||
out = _vqa_answer_to_loc(answer, height=240, width=320)
|
||||
assert out == "<loc0000><loc0000><loc1023><loc1023> cube"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
|
||||
assert _vqa_answer_to_loc({"label": "cubes", "count": 2}, 240, 320) is None
|
||||
assert _vqa_answer_to_loc({"weird": "payload"}, 240, 320) is None
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_rewrites_target_turn():
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "feature": "observation.images.top"},
|
||||
{"type": "text", "text": "where is the cube?"},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": '{"label": "cube", "point_format": "xy", "point": [160, 120]}'},
|
||||
]
|
||||
shapes = {"observation.images.top": (240, 320)}
|
||||
out = _messages_vqa_to_loc(messages, target_indices=[1], image_shapes=shapes)
|
||||
assert out[1]["content"] == "<loc0512><loc0512> cube"
|
||||
# input messages are not mutated
|
||||
assert messages[1]["content"].startswith("{")
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "image", "feature": "observation.images.top"}]},
|
||||
{"role": "assistant", "content": "pick up the cube"},
|
||||
]
|
||||
shapes = {"observation.images.top": (240, 320)}
|
||||
out = _messages_vqa_to_loc(messages, target_indices=[1], image_shapes=shapes)
|
||||
assert out[1]["content"] == "pick up the cube"
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_noop_without_shapes():
|
||||
messages = [{"role": "assistant", "content": '{"label": "c", "point_format": "xy", "point": [1, 2]}'}]
|
||||
assert _messages_vqa_to_loc(messages, [0], None) is messages
|
||||
assert _messages_vqa_to_loc(messages, [0], {}) is messages
|
||||
@@ -177,3 +177,52 @@ def test_observation_image_to_pil_from_batched_float_array():
|
||||
pil = observation_image_to_pil(arr)
|
||||
assert pil.size == (32, 24)
|
||||
assert pil.mode == "RGB"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_loc_keypoint_answer():
|
||||
# <locY><locX> label — y=512/1023≈0.5, x=256/1023≈0.25
|
||||
parsed = parse_vqa_answer("<loc0512><loc0256> blue cube")
|
||||
assert parsed["kind"] == "keypoint"
|
||||
assert parsed["normalized"] is True
|
||||
x, y = parsed["payload"]["point"]
|
||||
assert 0.24 < x < 0.26
|
||||
assert 0.49 < y < 0.51
|
||||
assert parsed["payload"]["label"] == "blue cube"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_loc_bbox_answer():
|
||||
# <locY0><locX0><locY1><locX1> label
|
||||
parsed = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert parsed["normalized"] is True
|
||||
det = parsed["payload"]["detections"][0]
|
||||
x1, y1, x2, y2 = det["bbox"]
|
||||
assert x1 < x2 and y1 < y2
|
||||
assert det["label"] == "yellow cube"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_loc_multiple_boxes():
|
||||
answer = "<loc0100><loc0080><loc0400><loc0360> cube ; <loc0200><loc0500><loc0600><loc0900> box"
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert len(parsed["payload"]["detections"]) == 2
|
||||
|
||||
|
||||
def test_parse_loc_takes_precedence_over_json():
|
||||
# An answer with <loc> tokens is parsed as loc even if JSON-ish.
|
||||
assert parse_vqa_answer('{"x": <loc0001><loc0002>}')["normalized"] is True
|
||||
|
||||
|
||||
def test_draw_loc_overlay_denormalizes_to_pixels():
|
||||
img = _blank((200, 100))
|
||||
parsed = parse_vqa_answer("<loc0511><loc0511> cube") # ~centre
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
|
||||
Reference in New Issue
Block a user